if (!b) {
b = buf_alloc(n);
assert(b);
+ } else {
+ b->refs++;
}
m->b = b;
return m;
}
}
+/*
+ * Compute m = m0 * m1 using one level of Strassen's seven-product scheme.
+ *
+ * Each matrix is split into four half-size views that share the parent's
+ * buffer (assuming mat_alloc's first two arguments are the row and column
+ * offset of the view inside that buffer):
+ *
+ *     m0 = | a b |    m1 = | e f |    m = | r s |
+ *          | c d |         | g h |        | t u |
+ *
+ * The seven half-size products p1..p7 are evaluated with mat_mul and then
+ * combined into the four output quadrants.
+ *
+ * Preconditions (asserted): m, m0 and m1 all have the same dimension, and
+ * that dimension is even so it splits cleanly into quadrants.
+ *
+ * m must not share its buffer with m0 or m1: output quadrants are written
+ * while input quadrants are still being read.
+ */
+void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
+{
+    assert(m0->n == m1->n);
+    assert(m->n == m0->n);
+    // An odd dimension would be truncated by the halving below and leave
+    // the last row/column of the result unwritten.
+    assert(m->n % 2 == 0);
+
+    int n = m->n / 2;
+
+    // Quadrant views of the left operand.
+    mat_t* a = mat_alloc(0, 0, n, m0->b);
+    mat_t* b = mat_alloc(0, n, n, m0->b);
+    mat_t* c = mat_alloc(n, 0, n, m0->b);
+    mat_t* d = mat_alloc(n, n, n, m0->b);
+
+    // Quadrant views of the right operand.
+    mat_t* e = mat_alloc(0, 0, n, m1->b);
+    mat_t* f = mat_alloc(0, n, n, m1->b);
+    mat_t* g = mat_alloc(n, 0, n, m1->b);
+    mat_t* h = mat_alloc(n, n, n, m1->b);
+
+    // Quadrant views of the result.
+    mat_t* r = mat_alloc(0, 0, n, m->b);
+    mat_t* s = mat_alloc(0, n, n, m->b);
+    mat_t* t = mat_alloc(n, 0, n, m->b);
+    mat_t* u = mat_alloc(n, n, n, m->b);
+
+    // The seven Strassen products, each backed by its own fresh buffer.
+    mat_t* p1 = mat_alloc(0, 0, n, NULL);
+    mat_t* p2 = mat_alloc(0, 0, n, NULL);
+    mat_t* p3 = mat_alloc(0, 0, n, NULL);
+    mat_t* p4 = mat_alloc(0, 0, n, NULL);
+    mat_t* p5 = mat_alloc(0, 0, n, NULL);
+    mat_t* p6 = mat_alloc(0, 0, n, NULL);
+    mat_t* p7 = mat_alloc(0, 0, n, NULL);
+
+    // Scratch matrices for the sums/differences fed into each product.
+    mat_t* t0 = mat_alloc(0, 0, n, NULL);
+    mat_t* t1 = mat_alloc(0, 0, n, NULL);
+
+    // p1 = a * (f - h)
+    mat_sub(t0, f, h);
+    mat_mul(p1, a, t0);
+
+    // p2 = (a + b) * h
+    mat_add(t0, a, b);
+    mat_mul(p2, t0, h);
+
+    // s = p1 + p2
+    mat_add(s, p1, p2);
+
+    // p3 = (c + d) * e
+    mat_add(t0, c, d);
+    mat_mul(p3, t0, e);
+
+    // p4 = d * (g - e)
+    mat_sub(t0, g, e);
+    mat_mul(p4, d, t0);
+
+    // t = p3 + p4
+    mat_add(t, p3, p4);
+
+    // p5 = (a + d) * (e + h)
+    mat_add(t0, a, d);
+    mat_add(t1, e, h);
+    mat_mul(p5, t0, t1);
+
+    // p6 = (b - d) * (g + h)
+    mat_sub(t0, b, d);
+    mat_add(t1, g, h);
+    mat_mul(p6, t0, t1);
+
+    // r = p5 + p4 - p2 + p6
+    mat_add(r, p5, p4);
+    mat_sub(r, r, p2);
+    mat_add(r, r, p6);
+
+    // p7 = (a - c) * (e + f)
+    mat_sub(t0, a, c);
+    mat_add(t1, e, f);
+    mat_mul(p7, t0, t1);
+
+    // u = p5 + p1 - p3 - p7
+    mat_add(u, p5, p1);
+    mat_sub(u, u, p3);
+    mat_sub(u, u, p7);
+
+    // Diagnostic dump of buffer reference counts before the views and
+    // temporaries are released.
+    printf("t0->b->refs: %d\n", t0->b->refs);
+    printf("p1->b->refs: %d\n", p1->b->refs);
+    printf("a->b->refs: %d\n", a->b->refs);
+    printf("e->b->refs: %d\n", e->b->refs);
+    printf("r->b->refs: %d\n", r->b->refs);
+
+    // Release scratch matrices, products and quadrant views.
+    mat_free(t0);
+    mat_free(t1);
+
+    mat_free(p1);
+    mat_free(p2);
+    mat_free(p3);
+    mat_free(p4);
+    mat_free(p5);
+    mat_free(p6);
+    mat_free(p7);
+
+    mat_free(a);
+    mat_free(b);
+    mat_free(c);
+    mat_free(d);
+
+    mat_free(e);
+    mat_free(f);
+    mat_free(g);
+    mat_free(h);
+
+    mat_free(r);
+    mat_free(s);
+    mat_free(t);
+    mat_free(u);
+}
+
void mat_fill(mat_t* m)
{
assert(m && m->b);
mat_t* m0 = mat_alloc(0, 0, n, NULL);
mat_t* m1 = mat_alloc(0, 0, n, NULL);
- mat_t* m = mat_alloc(0, 0, n, NULL);
+ mat_t* m2 = mat_alloc(0, 0, n, NULL);
+ mat_t* m3 = mat_alloc(0, 0, n, NULL);
mat_fill(m0);
mat_fill(m1);
- mat_mul(m, m0, m1);
-
+ mat_mul(m2, m0, m1);
+ mat_mul_strassen(m3, m0, m1);
+#if 1
mat_print(m0);
mat_print(m1);
- mat_print(m);
+ mat_print(m2);
+ mat_print(m3);
+#endif
return 0;
}