ok
authoryu.dongliang <maja_creater@qq.com>
Fri, 20 Nov 2020 04:34:38 +0000 (12:34 +0800)
committeryu.dongliang <maja_creater@qq.com>
Fri, 20 Nov 2020 04:34:38 +0000 (12:34 +0800)
mat.c

diff --git a/mat.c b/mat.c
index ee6758c42c34bc805e62f018e915b81c067d06d3..2b94027718a346b7a247ec219f66e32e3b75130a 100644 (file)
--- a/mat.c
+++ b/mat.c
@@ -146,6 +146,8 @@ void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
 
        int n = m->n / 2;
 
+       printf("m->n: %d, n: %d\n", m->n, n);
+
        mat_t* a = mat_alloc(0, 0, n, m0->b);
        mat_t* b = mat_alloc(0, n, n, m0->b);
        mat_t* c = mat_alloc(n, 0, n, m0->b);
@@ -171,7 +173,7 @@ void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
 
        // tmp mat t0
        mat_t* t0 = mat_alloc(0, 0, n, NULL);
-
+#if 1
        // p1 = a * (f - h)
        mat_sub(t0, f, h);
        mat_mul(p1, a, t0);
@@ -182,18 +184,22 @@ void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
 
        // s = p1 + p2
        mat_add(s, p1,  p2);
+#endif
 
+#if 1
        // p3 = (c + d) * e
        mat_add(t0, c,  d);
        mat_mul(p3, t0, e);
 
        // p4 = d * (g - e)
        mat_sub(t0, g,  e);
-       mat_mul(p3, d,  t0);
+       mat_mul(p4, d,  t0);
 
        // t = p3 + p4
        mat_add(t, p3,  p4);
+#endif
 
+#if 1
        // tmp mat t1
        mat_t* t1 = mat_alloc(0, 0, n, NULL);
 
@@ -205,13 +211,15 @@ void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
        //p6 = (b - d) * (g + h)
        mat_sub(t0, b,  d);
        mat_add(t1, g,  h);
-       mat_mul(p5, t0, t1);
+       mat_mul(p6, t0, t1);
 
        // r = p5 + p4 - p2 + p6
        mat_add(r, p5,  p4);
        mat_sub(r, r,   p2);
        mat_add(r, r,   p6);
+#endif
 
+#if 1
        //p7 = (a - c) * (e + f)
        mat_sub(t0, a,  c);
        mat_add(t1, e,  f);
@@ -221,12 +229,8 @@ void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
        mat_add(u, p5,  p1);
        mat_sub(u, u,   p3);
        mat_sub(u, u,   p7);
+#endif
 
-       printf("t0->b->refs: %d\n", t0->b->refs);
-       printf("p1->b->refs: %d\n", p1->b->refs);
-       printf("a->b->refs: %d\n",  a->b->refs);
-       printf("e->b->refs: %d\n",  e->b->refs);
-       printf("r->b->refs: %d\n",  r->b->refs);
        // free unused mats
        mat_free(t0);
        mat_free(t1);