Skip to content

Commit fc2f628

Browse files
docs: improve Strassen's algorithm docstrings with complexity analysis
- Add O(n^2.807) time complexity explanation to actual_strassen() - Explain why 7 multiplications beats naive 8 multiplications - Add docstring to strassen() with args, returns, and raises - Add doctest to default_matrix_multiplication() - Add one-line docstrings to helper functions Ref #14084
1 parent 02680c9 commit fc2f628

1 file changed

Lines changed: 48 additions & 3 deletions

File tree

divide_and_conquer/strassen_matrix_multiplication.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55

66
def default_matrix_multiplication(a: list, b: list) -> list:
77
"""
8-
Multiplication only for 2x2 matrices
8+
Standard multiplication for 2x2 matrices (base case).
9+
10+
Used as the base case for Strassen's algorithm when the matrix
11+
cannot be subdivided further. Uses 8 multiplications.
12+
13+
Time complexity: O(1) — fixed size input.
14+
15+
>>> default_matrix_multiplication([[1, 2], [3, 4]], [[5, 6], [7, 8]])
16+
[[19, 22], [43, 50]]
917
"""
1018
if len(a) != 2 or len(a[0]) != 2 or len(b) != 2 or len(b[0]) != 2:
1119
raise Exception("Matrices are not 2x2")
@@ -17,13 +25,15 @@ def default_matrix_multiplication(a: list, b: list) -> list:
1725

1826

1927
def matrix_addition(matrix_a: list, matrix_b: list):
28+
"""Element-wise addition of two matrices of equal dimensions."""
2029
return [
2130
[matrix_a[row][col] + matrix_b[row][col] for col in range(len(matrix_a[row]))]
2231
for row in range(len(matrix_a))
2332
]
2433

2534

2635
def matrix_subtraction(matrix_a: list, matrix_b: list):
36+
"""Element-wise subtraction of two matrices of equal dimensions."""
2737
return [
2838
[matrix_a[row][col] - matrix_b[row][col] for col in range(len(matrix_a[row]))]
2939
for row in range(len(matrix_a))
@@ -64,6 +74,7 @@ def split_matrix(a: list) -> tuple[list, list, list, list]:
6474

6575

6676
def matrix_dimensions(matrix: list) -> tuple[int, int]:
77+
"""Return (rows, columns) of a matrix."""
6778
return len(matrix), len(matrix[0])
6879

6980

@@ -73,8 +84,22 @@ def print_matrix(matrix: list) -> None:
7384

7485
def actual_strassen(matrix_a: list, matrix_b: list) -> list:
7586
"""
76-
Recursive function to calculate the product of two matrices, using the Strassen
77-
Algorithm. It only supports square matrices of any size that is a power of 2.
87+
Recursive function to calculate the product of two matrices using Strassen's
88+
algorithm. Only supports square matrices whose dimensions are a power of 2.
89+
90+
Strassen's algorithm reduces matrix multiplication from 8 recursive
91+
multiplications (naive divide-and-conquer) to 7, at the cost of more
92+
additions and subtractions. This gives a better asymptotic complexity:
93+
94+
- Naive matrix multiplication: O(n^3)
95+
- Naive divide-and-conquer: O(n^3) — 8 multiplications of n/2 size
96+
- Strassen's algorithm: O(n^2.807) — 7 multiplications of n/2 size
97+
98+
The 7 intermediate products (t1-t7) are combined to form the four
99+
quadrants of the result matrix using only additions and subtractions.
100+
101+
Reference: Strassen, V. (1969). Gaussian elimination is not optimal.
102+
Numerische Mathematik, 13(4), 354-356.
78103
"""
79104
if matrix_dimensions(matrix_a) == (2, 2):
80105
return default_matrix_multiplication(matrix_a, matrix_b)
@@ -106,6 +131,26 @@ def actual_strassen(matrix_a: list, matrix_b: list) -> list:
106131

107132
def strassen(matrix1: list, matrix2: list) -> list:
108133
"""
134+
Multiply two matrices of arbitrary dimensions using Strassen's algorithm.
135+
136+
Handles non-square and non-power-of-2 matrices by padding with zeros
137+
to the next power of 2, running Strassen's algorithm, then removing
138+
the padding from the result.
139+
140+
Time complexity: O(n^2.807) where n is the padded dimension.
141+
Space complexity: O(n^2) for the padded matrices.
142+
143+
Args:
144+
matrix1: First matrix (m x n).
145+
matrix2: Second matrix (n x p). Number of columns in matrix1
146+
must equal number of rows in matrix2.
147+
148+
Returns:
149+
Result matrix (m x p).
150+
151+
Raises:
152+
Exception: If matrix dimensions are incompatible for multiplication.
153+
109154
>>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]])
110155
[[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]]
111156
>>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])

0 commit comments

Comments
 (0)