}
}
-void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
+void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1, int n_min)
{
assert(m0->n == m1->n);
assert(m->n == m0->n);
- int n = m->n / 2;
+ if (m->n <= n_min)
+ return mat_mul(m, m0, m1);
+
+ if (n_min < 16)
+ printf("%s(), n_min: %d\n", __func__, n_min);
- printf("m->n: %d, n: %d\n", m->n, n);
+ int n = m->n / 2;
mat_t* a = mat_alloc(0, 0, n, m0->b);
mat_t* b = mat_alloc(0, n, n, m0->b);
// tmp mat t0
mat_t* t0 = mat_alloc(0, 0, n, NULL);
-#if 1
+
// p1 = a * (f - h)
mat_sub(t0, f, h);
- mat_mul(p1, a, t0);
+ mat_mul_strassen(p1, a, t0, n_min);
// p2 = (a + b) * h
mat_add(t0, a, b);
- mat_mul(p2, t0, h);
+ mat_mul_strassen(p2, t0, h, n_min);
// s = p1 + p2
mat_add(s, p1, p2);
-#endif
-#if 1
+
// p3 = (c + d) * e
mat_add(t0, c, d);
- mat_mul(p3, t0, e);
+ mat_mul_strassen(p3, t0, e, n_min);
// p4 = d * (g - e)
mat_sub(t0, g, e);
- mat_mul(p4, d, t0);
+ mat_mul_strassen(p4, d, t0, n_min);
// t = p3 + p4
mat_add(t, p3, p4);
-#endif
-#if 1
+
// tmp mat t1
mat_t* t1 = mat_alloc(0, 0, n, NULL);
//p5 = (a + d) * (e + h)
mat_add(t0, a, d);
mat_add(t1, e, h);
- mat_mul(p5, t0, t1);
+ mat_mul_strassen(p5, t0, t1, n_min);
//p6 = (b - d) * (g + h)
mat_sub(t0, b, d);
mat_add(t1, g, h);
- mat_mul(p6, t0, t1);
+ mat_mul_strassen(p6, t0, t1, n_min);
// r = p5 + p4 - p2 + p6
mat_add(r, p5, p4);
mat_sub(r, r, p2);
mat_add(r, r, p6);
-#endif
-#if 1
+
//p7 = (a - c) * (e + f)
mat_sub(t0, a, c);
mat_add(t1, e, f);
- mat_mul(p7, t0, t1);
+ mat_mul_strassen(p7, t0, t1, n_min);
// u = p5 + p1 -p3 -p7
mat_add(u, p5, p1);
mat_sub(u, u, p3);
mat_sub(u, u, p7);
-#endif
// free unused mats
mat_free(t0);
mat_mul(m2, m0, m1);
break;
case 1:
- mat_mul_strassen(m3, m0, m1);
+ mat_mul_strassen(m3, m0, m1, 64);
break;
case 2:
mat_mul(m2, m0, m1);
- mat_mul_strassen(m3, m0, m1);
+ mat_mul_strassen(m3, m0, m1, 1);
mat_print(m0);
mat_print(m1);