55
66def default_matrix_multiplication (a : list , b : list ) -> list :
77 """
8- Multiplication only for 2x2 matrices
8+ Standard multiplication for 2x2 matrices (base case).
9+
10+ Used as the base case for Strassen's algorithm when the matrix
11+ cannot be subdivided further. Uses 8 multiplications.
12+
13+ Time complexity: O(1) — fixed size input.
14+
15+ >>> default_matrix_multiplication([[1, 2], [3, 4]], [[5, 6], [7, 8]])
16+ [[19, 22], [43, 50]]
917 """
1018 if len (a ) != 2 or len (a [0 ]) != 2 or len (b ) != 2 or len (b [0 ]) != 2 :
1119 raise Exception ("Matrices are not 2x2" )
@@ -17,13 +25,15 @@ def default_matrix_multiplication(a: list, b: list) -> list:
1725
1826
1927def matrix_addition (matrix_a : list , matrix_b : list ):
28+ """Element-wise addition of two matrices of equal dimensions."""
2029 return [
2130 [matrix_a [row ][col ] + matrix_b [row ][col ] for col in range (len (matrix_a [row ]))]
2231 for row in range (len (matrix_a ))
2332 ]
2433
2534
2635def matrix_subtraction (matrix_a : list , matrix_b : list ):
36+ """Element-wise subtraction of two matrices of equal dimensions."""
2737 return [
2838 [matrix_a [row ][col ] - matrix_b [row ][col ] for col in range (len (matrix_a [row ]))]
2939 for row in range (len (matrix_a ))
@@ -64,6 +74,7 @@ def split_matrix(a: list) -> tuple[list, list, list, list]:
6474
6575
6676def matrix_dimensions (matrix : list ) -> tuple [int , int ]:
77+ """Return (rows, columns) of a matrix."""
6778 return len (matrix ), len (matrix [0 ])
6879
6980
@@ -73,8 +84,22 @@ def print_matrix(matrix: list) -> None:
7384
7485def actual_strassen (matrix_a : list , matrix_b : list ) -> list :
7586 """
76- Recursive function to calculate the product of two matrices, using the Strassen
77- Algorithm. It only supports square matrices of any size that is a power of 2.
87+ Recursive function to calculate the product of two matrices using Strassen's
88+ algorithm. Only supports square matrices whose dimensions are a power of 2.
89+
90+ Strassen's algorithm reduces matrix multiplication from 8 recursive
91+ multiplications (naive divide-and-conquer) to 7, at the cost of more
92+ additions and subtractions. This gives a better asymptotic complexity:
93+
94+ - Naive matrix multiplication: O(n^3)
95+ - Naive divide-and-conquer: O(n^3) — 8 multiplications of n/2 size
96+ - Strassen's algorithm: O(n^2.807) — 7 multiplications of n/2 size
97+
98+ The 7 intermediate products (t1-t7) are combined to form the four
99+ quadrants of the result matrix using only additions and subtractions.
100+
101+ Reference: Strassen, V. (1969). Gaussian elimination is not optimal.
102+ Numerische Mathematik, 13(4), 354-356.
78103 """
79104 if matrix_dimensions (matrix_a ) == (2 , 2 ):
80105 return default_matrix_multiplication (matrix_a , matrix_b )
@@ -106,6 +131,26 @@ def actual_strassen(matrix_a: list, matrix_b: list) -> list:
106131
107132def strassen (matrix1 : list , matrix2 : list ) -> list :
108133 """
134+ Multiply two matrices of arbitrary dimensions using Strassen's algorithm.
135+
136+ Handles non-square and non-power-of-2 matrices by padding with zeros
137+ to the next power of 2, running Strassen's algorithm, then removing
138+ the padding from the result.
139+
140+ Time complexity: O(n^2.807) where n is the padded dimension.
141+ Space complexity: O(n^2) for the padded matrices.
142+
143+ Args:
144+ matrix1: First matrix (m x n).
145+ matrix2: Second matrix (n x p). Number of columns in matrix1
146+ must equal number of rows in matrix2.
147+
148+ Returns:
149+ Result matrix (m x p).
150+
151+ Raises:
152+ Exception: If matrix dimensions are incompatible for multiplication.
153+
109154 >>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]])
110155 [[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]]
111156 >>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])
0 commit comments