strassen
authoryu.dongliang <maja_creater@qq.com>
Fri, 20 Nov 2020 04:22:44 +0000 (12:22 +0800)
committeryu.dongliang <maja_creater@qq.com>
Fri, 20 Nov 2020 04:22:44 +0000 (12:22 +0800)
mat.c

diff --git a/mat.c b/mat.c
index f82a3f7c8b9d9693f564e669eeb7a7574890b8aa..ee6758c42c34bc805e62f018e915b81c067d06d3 100644 (file)
--- a/mat.c
+++ b/mat.c
@@ -53,6 +53,8 @@ mat_t* mat_alloc(int i, int j, int n, buf_t* b)
        if (!b) {
                b = buf_alloc(n);
                assert(b);
+       } else {
+               b->refs++;
        }
        m->b = b;
        return m;
@@ -137,6 +139,122 @@ void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
        }
 }
 
+void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
+{
+       assert(m0->n == m1->n);
+       assert(m->n  == m0->n);
+
+       int n = m->n / 2;
+
+       mat_t* a = mat_alloc(0, 0, n, m0->b);
+       mat_t* b = mat_alloc(0, n, n, m0->b);
+       mat_t* c = mat_alloc(n, 0, n, m0->b);
+       mat_t* d = mat_alloc(n, n, n, m0->b);
+
+       mat_t* e = mat_alloc(0, 0, n, m1->b);
+       mat_t* f = mat_alloc(0, n, n, m1->b);
+       mat_t* g = mat_alloc(n, 0, n, m1->b);
+       mat_t* h = mat_alloc(n, n, n, m1->b);
+
+       mat_t* r = mat_alloc(0, 0, n, m->b);
+       mat_t* s = mat_alloc(0, n, n, m->b);
+       mat_t* t = mat_alloc(n, 0, n, m->b);
+       mat_t* u = mat_alloc(n, n, n, m->b);
+
+       mat_t* p1 = mat_alloc(0, 0, n, NULL);
+       mat_t* p2 = mat_alloc(0, 0, n, NULL);
+       mat_t* p3 = mat_alloc(0, 0, n, NULL);
+       mat_t* p4 = mat_alloc(0, 0, n, NULL);
+       mat_t* p5 = mat_alloc(0, 0, n, NULL);
+       mat_t* p6 = mat_alloc(0, 0, n, NULL);
+       mat_t* p7 = mat_alloc(0, 0, n, NULL);
+
+       // tmp mat t0
+       mat_t* t0 = mat_alloc(0, 0, n, NULL);
+
+       // p1 = a * (f - h)
+       mat_sub(t0, f, h);
+       mat_mul(p1, a, t0);
+
+       // p2 = (a + b) * h
+       mat_add(t0, a,  b);
+       mat_mul(p2, t0, h);
+
+       // s = p1 + p2
+       mat_add(s, p1,  p2);
+
+       // p3 = (c + d) * e
+       mat_add(t0, c,  d);
+       mat_mul(p3, t0, e);
+
+       // p4 = d * (g - e)
+       mat_sub(t0, g,  e);
+       mat_mul(p3, d,  t0);
+
+       // t = p3 + p4
+       mat_add(t, p3,  p4);
+
+       // tmp mat t1
+       mat_t* t1 = mat_alloc(0, 0, n, NULL);
+
+       //p5 = (a + d) * (e + h)
+       mat_add(t0, a,  d);
+       mat_add(t1, e,  h);
+       mat_mul(p5, t0, t1);
+
+       //p6 = (b - d) * (g + h)
+       mat_sub(t0, b,  d);
+       mat_add(t1, g,  h);
+       mat_mul(p5, t0, t1);
+
+       // r = p5 + p4 - p2 + p6
+       mat_add(r, p5,  p4);
+       mat_sub(r, r,   p2);
+       mat_add(r, r,   p6);
+
+       //p7 = (a - c) * (e + f)
+       mat_sub(t0, a,  c);
+       mat_add(t1, e,  f);
+       mat_mul(p7, t0, t1);
+
+       // u = p5 + p1 -p3 -p7
+       mat_add(u, p5,  p1);
+       mat_sub(u, u,   p3);
+       mat_sub(u, u,   p7);
+
+       printf("t0->b->refs: %d\n", t0->b->refs);
+       printf("p1->b->refs: %d\n", p1->b->refs);
+       printf("a->b->refs: %d\n",  a->b->refs);
+       printf("e->b->refs: %d\n",  e->b->refs);
+       printf("r->b->refs: %d\n",  r->b->refs);
+       // free unused mats
+       mat_free(t0);
+       mat_free(t1);
+
+       mat_free(p1);
+       mat_free(p2);
+       mat_free(p3);
+       mat_free(p4);
+       mat_free(p5);
+       mat_free(p6);
+       mat_free(p7);
+
+       mat_free(a);
+       mat_free(b);
+       mat_free(c);
+       mat_free(d);
+
+       mat_free(e);
+       mat_free(f);
+       mat_free(g);
+       mat_free(h);
+
+       mat_free(r);
+       mat_free(s);
+       mat_free(t);
+       mat_free(u);
+}
+
 void mat_fill(mat_t* m)
 {
        assert(m && m->b);
@@ -189,16 +307,20 @@ int main(int argc, char* argv[])
 
        mat_t* m0 = mat_alloc(0, 0, n, NULL); 
        mat_t* m1 = mat_alloc(0, 0, n, NULL);
-       mat_t* m  = mat_alloc(0, 0, n, NULL);
+       mat_t* m2 = mat_alloc(0, 0, n, NULL);
+       mat_t* m3 = mat_alloc(0, 0, n, NULL);
 
        mat_fill(m0);
        mat_fill(m1);
 
-       mat_mul(m, m0, m1);
-
+       mat_mul(m2, m0, m1);
+       mat_mul_strassen(m3, m0, m1);
+#if 1
        mat_print(m0);
        mat_print(m1);
-       mat_print(m);
+       mat_print(m2);
+       mat_print(m3);
+#endif
        return 0;
 }