矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)

一般的矩阵乘法算法时间复杂度为

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(1)

1969年,Volker Strassen第一个提出了复杂度低于

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(2)

的矩阵乘法算法,算法时间复杂度为

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(3)

。Strassen算法证明了存在时间复杂度低于

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(4)

的算法。

假设矩阵 A 和矩阵 B 都是

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(5)

的方矩阵,求 C=AB ,如下所示:

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(6)

其中

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(7)

矩阵 C 可以通过下列公式求出:

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(8)

从上述公式我们可以得出,计算2个 n * n 的矩阵相乘需要2个

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(9)

的矩阵8次乘法和4次加法。我们使用 T (n) 表示 n * n 矩阵乘法的时间复杂度,那么我们可以根据上面的分解得到下面的递推公式:

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(10)

其中,

  1. 1.

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(11)

表示8次矩阵乘法,而且相乘的矩阵规模降到了

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(12)

  1. 2.

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(13)

表示4次矩阵加法的时间复杂度以及合并矩阵 C 的时间复杂度。

最终可计算得到

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(14)

现在,我们来看一下Strassen算法的原理。

仍然把每个矩阵分割为4份,然后创建如下10个中间矩阵:

S1 = B12 - B22S2 = A11 A12S3 = A21 A22S4 = B21 - B11S5 = A11 A22S6 = B11 B22S7 = A12 - A22S8 = B21 B22S9 = A11 - A21S10 = B11 B12

接着,计算7次矩阵乘法:

P1 = A11 • S1P2 = S2 • B22P3 = S3 • B11P4 = A22 • S4P5 = S5 • S6P6 = S7 • S8P7 = S9 • S10

最后,根据这7个结果就可以计算出C矩阵:

C11 = P5 P4 - P2 P6C12 = P1 P2C21 = P3 P4C22 = P5 P1 - P3 - P7

T(n) = 7T(n/2) Θ(n2)

使用递归树或主方法可以计算出结果:

T(n) = Θ(nlg7) ≈ Θ(n2.81)

下图展示了平凡算法和Strassen算法的性能差异,n越大,Strassen算法节约的时间越多。

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(15)

代码如下:

import java.util.Arrays;

public class MatrixMultiply {

public static void SquareMatrixMultiply(int A[][], int B[][]) {

int rows = A.length;

int C[][] = new int[rows][rows];

for (int i = 0; i < rows; i ) {

for (int j = 0; j < rows; j ) {

C[i][j] = 0;

for (int k = 0; k < rows; k ) {

C[i][j] = A[i][k] * B[k][j];

}

}

}

displaySquare(C);

}

public static void displaySquare(int matrix[][]) {

for (int i = 0; i < matrix.length; i ) {

for (int j : matrix[i]) {

System.out.print(j " ");

}

System.out.println();

}

}

public static void copyToMatrixArray(int srcMatrix[][], int startI, int startJ, int iLen, int jLen,

int destMatrix[][]) {

for (int i = startI; i < startI iLen; i ) {

for (int j = startJ; j < startJ jLen; j ) {

destMatrix[i - startI][j - startJ] = srcMatrix[i][j];

}

}

}

public static void copyFromMatrixArray(int destMatrix[][], int startI, int startJ, int iLen, int jLen,

int srcMatrix[][]) {

for (int i = 0; i < iLen; i ) {

for (int j = 0; j < jLen; j ) {

destMatrix[startI i][startJ j] = srcMatrix[i][j];

}

}

}

public static void squareMatrixAdd(int A[][], int B[][], int C[][]) {

for (int i = 0; i < A.length; i ) {

for (int j = 0; j < A[i].length; j ) {

C[i][j] = A[i][j] B[i][j];

}

}

}

public static void squareMatrixSub(int A[][], int B[][], int C[][]) {

for (int i = 0; i < A.length; i ) {

for (int j = 0; j < A[i].length; j ) {

C[i][j] = A[i][j] - B[i][j];

}

}

}

public static int[][] squareMatrixMultiplyRecursive(int A[][], int B[][]) {

int n = A.length;

int C[][] = new int[n][n];

if (n == 1) {

C[0][0] = A[0][0] * B[0][0];

} else {

int A11[][], A12[][], A21[][], A22[][];

int B11[][], B12[][], B21[][], B22[][];

int C11[][], C12[][], C21[][], C22[][];

A11 = new int[n/2][n/2];A12 = new int[n/2][n/2];A21 = new int[n/2][n/2];A22 = new int[n/2][n/2];

copyToMatrixArray(A, 0, 0, n/2, n/2, A11);

copyToMatrixArray(A, 0, n/2, n/2, n/2, A12);

copyToMatrixArray(A, n/2, 0, n/2, n/2, A21);

copyToMatrixArray(A, n/2, n/2, n/2, n/2, A22);

B11 = new int[n/2][n/2];B12 = new int[n/2][n/2];B21 = new int[n/2][n/2];B22 = new int[n/2][n/2];

copyToMatrixArray(B, 0, 0, n/2, n/2, B11);

copyToMatrixArray(B, 0, n/2, n/2, n/2, B12);

copyToMatrixArray(B, n/2, 0, n/2, n/2, B21);

copyToMatrixArray(B, n/2, n/2, n/2, n/2, B22);

C11 = new int[n/2][n/2];C12 = new int[n/2][n/2];C21 = new int[n/2][n/2];C22 = new int[n/2][n/2];

squareMatrixAdd(squareMatrixMultiplyRecursive(A11, B11), squareMatrixMultiplyRecursive(A12, B21),

C11);

squareMatrixAdd(squareMatrixMultiplyRecursive(A11, B12), squareMatrixMultiplyRecursive(A12, B22),

C12);

squareMatrixAdd(squareMatrixMultiplyRecursive(A21, B11), squareMatrixMultiplyRecursive(A22, B21),

C21);

squareMatrixAdd(squareMatrixMultiplyRecursive(A21, B12), squareMatrixMultiplyRecursive(A22, B22),

C22);

copyFromMatrixArray(C, 0, 0, n/2, n/2, C11);

copyFromMatrixArray(C, 0, n/2, n/2, n/2, C12);

copyFromMatrixArray(C, n/2, 0, n/2, n/2, C21);

copyFromMatrixArray(C, n/2, n/2, n/2, n/2, C22);

}

return C;

}

public static int[][] strassenMatrixMultiplyRecursive(int A[][], int B[][]) {

int n = A.length;

int C[][] = new int[n][n];

if (n == 1) {

C[0][0] = A[0][0] * B[0][0];

} else {

int A11[][], A12[][], A21[][], A22[][];

int B11[][], B12[][], B21[][], B22[][];

int C11[][], C12[][], C21[][], C22[][];

int S1[][], S2[][], S3[][], S4[][], S5[][], S6[][], S7[][], S8[][], S9[][], S10[][];

int P1[][], P2[][], P3[][], P4[][], P5[][], P6[][], P7[][];

A11 = new int[n/2][n/2];A12 = new int[n/2][n/2];A21 = new int[n/2][n/2];A22 = new int[n/2][n/2];

copyToMatrixArray(A, 0, 0, n/2, n/2, A11);

copyToMatrixArray(A, 0, n/2, n/2, n/2, A12);

copyToMatrixArray(A, n/2, 0, n/2, n/2, A21);

copyToMatrixArray(A, n/2, n/2, n/2, n/2, A22);

B11 = new int[n/2][n/2];B12 = new int[n/2][n/2];B21 = new int[n/2][n/2];B22 = new int[n/2][n/2];

copyToMatrixArray(B, 0, 0, n/2, n/2, B11);

copyToMatrixArray(B, 0, n/2, n/2, n/2, B12);

copyToMatrixArray(B, n/2, 0, n/2, n/2, B21);

copyToMatrixArray(B, n/2, n/2, n/2, n/2, B22);

S1 = new int[n/2][n/2];S2 = new int[n/2][n/2];S3 = new int[n/2][n/2];S4 = new int[n/2][n/2];

S5 = new int[n/2][n/2];S6 = new int[n/2][n/2];S7 = new int[n/2][n/2];S8 = new int[n/2][n/2];

S9 = new int[n/2][n/2];S10 = new int[n/2][n/2];

squareMatrixSub(B12, B22, S1);squareMatrixAdd(A11, A12, S2);squareMatrixAdd(A21, A22, S3);

squareMatrixSub(B21, B11, S4);squareMatrixAdd(A11, A22, S5);squareMatrixAdd(B11, B22, S6);

squareMatrixSub(A12, A22, S7);squareMatrixAdd(B21, B22, S8);squareMatrixSub(A11, A21, S9);

squareMatrixAdd(B11, B12, S10);

P1 = new int[n/2][n/2];P2 = new int[n/2][n/2];P3 = new int[n/2][n/2];P4 = new int[n/2][n/2];

P5 = new int[n/2][n/2];P6 = new int[n/2][n/2];P7 = new int[n/2][n/2];

P1 = strassenMatrixMultiplyRecursive(A11, S1);

P2 = strassenMatrixMultiplyRecursive(S2, B22);

P3 = strassenMatrixMultiplyRecursive(S3, B11);

P4 = strassenMatrixMultiplyRecursive(A22, S4);

P5 = strassenMatrixMultiplyRecursive(S5, S6);

P6 = strassenMatrixMultiplyRecursive(S7, S8);

P7 = strassenMatrixMultiplyRecursive(S9, S10);

C11 = new int[n/2][n/2];C12 = new int[n/2][n/2];C21 = new int[n/2][n/2];C22 = new int[n/2][n/2];

int temp[][] = new int[n/2][n/2];

squareMatrixAdd(P5, P4, temp);

squareMatrixSub(temp, P2, temp);

squareMatrixAdd(temp, P6, C11);

squareMatrixAdd(P1, P2, C12);

squareMatrixAdd(P3, P4, C21);

squareMatrixAdd(P5, P1, temp);

squareMatrixSub(temp, P3, temp);

squareMatrixSub(temp, P7, C22);

copyFromMatrixArray(C, 0, 0, n/2, n/2, C11);

copyFromMatrixArray(C, 0, n/2, n/2, n/2, C12);

copyFromMatrixArray(C, n/2, 0, n/2, n/2, C21);

copyFromMatrixArray(C, n/2, n/2, n/2, n/2, C22);

}

return C;

}

public static int sMatrixA[][] = new int[][] {

{1, 2, 3, 4, 5, 6, 7, 8},

{1, 2, 3, 4, 5, 6, 7, 8},

{1, 2, 3, 4, 5, 6, 7, 8},

{1, 2, 3, 4, 5, 6, 7, 8},

{1, 2, 3, 4, 5, 6, 7, 8},

{1, 2, 3, 4, 5, 6, 7, 8},

{1, 2, 3, 4, 5, 6, 7, 8},

{1, 2, 3, 4, 5, 6, 7, 8},

};

public static int sMatrixB[][] = new int[][] {

{5, 6, 7, 8, 1, 2, 3, 4},

{5, 6, 7, 8, 1, 2, 3, 4},

{5, 6, 7, 8, 1, 2, 3, 4},

{5, 6, 7, 8, 1, 2, 3, 4},

{5, 6, 7, 8, 1, 2, 3, 4},

{5, 6, 7, 8, 1, 2, 3, 4},

{5, 6, 7, 8, 1, 2, 3, 4},

{5, 6, 7, 8, 1, 2, 3, 4},

};

public static void main(String[] args) {

System.out.println("普通矩阵乘法");

SquareMatrixMultiply(sMatrixA, sMatrixB);

System.out.println("\n递归矩阵乘法");

int C[][] = squareMatrixMultiplyRecursive(sMatrixA, sMatrixB);

displaySquare(C);

System.out.println("\n Strassen 递归矩阵乘法");

C = strassenMatrixMultiplyRecursive(sMatrixA, sMatrixB);

displaySquare(C);

}

}

矩阵乘法算法优化(算法之2矩阵乘法的Strassen算法)(16)

注:凡属于本公众号内容,未经允许不得私自转载,否则将依法追究侵权责任。

,

免责声明:本文仅代表文章作者的个人观点,与本站无关。其原创性、真实性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容文字的真实性、完整性和原创性本站不作任何保证或承诺,请读者仅作参考,并自行核实相关内容。文章投诉邮箱:anhduc.ph@yahoo.com

    分享
    投诉
    首页