From: yu.dongliang
Date: Fri, 20 Nov 2020 04:22:44 +0000 (+0800)
Subject: strassen
X-Git-Url: http://baseworks.info/?a=commitdiff_plain;h=96bf85567c6e86ddd1f021da1d379fc13fc9e714;p=mat.git

strassen
---

diff --git a/mat.c b/mat.c
index f82a3f7..ee6758c 100644
--- a/mat.c
+++ b/mat.c
@@ -53,6 +53,8 @@ mat_t* mat_alloc(int i, int j, int n, buf_t* b)
 	if (!b) {
 		b = buf_alloc(n);
 		assert(b);
+	} else {
+		b->refs++;
 	}
 	m->b = b;
 	return m;
@@ -137,6 +139,136 @@ void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
 	}
 }
 
+// Multiply m0 by m1 into m with one level of Strassen's algorithm.
+//
+// Each operand is split into four quadrant sub-matrices that share the
+// parent's buffer (a..d from m0, e..h from m1, r..u from m). Seven
+// half-size products p1..p7 are computed with mat_mul() and combined
+// into the result quadrants:
+//   r = p5 + p4 - p2 + p6    s = p1 + p2
+//   t = p3 + p4              u = p5 + p1 - p3 - p7
+//
+// m, m0 and m1 must all have the same dimension (asserted below).
+// NOTE(review): quadrants are (m->n / 2) wide, so an odd dimension would
+// presumably leave the last row and column of m untouched -- confirm that
+// callers only pass even sizes.
+void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
+{
+	assert(m0->n == m1->n);
+	assert(m->n == m0->n);
+
+	int n = m->n / 2;
+
+	mat_t* a = mat_alloc(0, 0, n, m0->b);
+	mat_t* b = mat_alloc(0, n, n, m0->b);
+	mat_t* c = mat_alloc(n, 0, n, m0->b);
+	mat_t* d = mat_alloc(n, n, n, m0->b);
+
+	mat_t* e = mat_alloc(0, 0, n, m1->b);
+	mat_t* f = mat_alloc(0, n, n, m1->b);
+	mat_t* g = mat_alloc(n, 0, n, m1->b);
+	mat_t* h = mat_alloc(n, n, n, m1->b);
+
+	mat_t* r = mat_alloc(0, 0, n, m->b);
+	mat_t* s = mat_alloc(0, n, n, m->b);
+	mat_t* t = mat_alloc(n, 0, n, m->b);
+	mat_t* u = mat_alloc(n, n, n, m->b);
+
+	mat_t* p1 = mat_alloc(0, 0, n, NULL);
+	mat_t* p2 = mat_alloc(0, 0, n, NULL);
+	mat_t* p3 = mat_alloc(0, 0, n, NULL);
+	mat_t* p4 = mat_alloc(0, 0, n, NULL);
+	mat_t* p5 = mat_alloc(0, 0, n, NULL);
+	mat_t* p6 = mat_alloc(0, 0, n, NULL);
+	mat_t* p7 = mat_alloc(0, 0, n, NULL);
+
+	// tmp mat t0
+	mat_t* t0 = mat_alloc(0, 0, n, NULL);
+
+	// p1 = a * (f - h)
+	mat_sub(t0, f, h);
+	mat_mul(p1, a, t0);
+
+	// p2 = (a + b) * h
+	mat_add(t0, a, b);
+	mat_mul(p2, t0, h);
+
+	// s = p1 + p2
+	mat_add(s, p1, p2);
+
+	// p3 = (c + d) * e
+	mat_add(t0, c, d);
+	mat_mul(p3, t0, e);
+
+	// p4 = d * (g - e)
+	mat_sub(t0, g, e);
+	mat_mul(p4, d, t0);
+
+	// t = p3 + p4
+	mat_add(t, p3, p4);
+
+	// tmp mat t1
+	mat_t* t1 = mat_alloc(0, 0, n, NULL);
+
+	// p5 = (a + d) * (e + h)
+	mat_add(t0, a, d);
+	mat_add(t1, e, h);
+	mat_mul(p5, t0, t1);
+
+	// p6 = (b - d) * (g + h)
+	mat_sub(t0, b, d);
+	mat_add(t1, g, h);
+	mat_mul(p6, t0, t1);
+
+	// r = p5 + p4 - p2 + p6
+	mat_add(r, p5, p4);
+	mat_sub(r, r, p2);
+	mat_add(r, r, p6);
+
+	// p7 = (a - c) * (e + f)
+	mat_sub(t0, a, c);
+	mat_add(t1, e, f);
+	mat_mul(p7, t0, t1);
+
+	// u = p5 + p1 - p3 - p7
+	mat_add(u, p5, p1);
+	mat_sub(u, u, p3);
+	mat_sub(u, u, p7);
+
+	// debug: print buffer reference counts before the matrices are freed
+	printf("t0->b->refs: %d\n", t0->b->refs);
+	printf("p1->b->refs: %d\n", p1->b->refs);
+	printf("a->b->refs: %d\n", a->b->refs);
+	printf("e->b->refs: %d\n", e->b->refs);
+	printf("r->b->refs: %d\n", r->b->refs);
+	// free the temporaries, the products and the quadrant sub-matrices
+	mat_free(t0);
+	mat_free(t1);
+
+	mat_free(p1);
+	mat_free(p2);
+	mat_free(p3);
+	mat_free(p4);
+	mat_free(p5);
+	mat_free(p6);
+	mat_free(p7);
+
+	mat_free(a);
+	mat_free(b);
+	mat_free(c);
+	mat_free(d);
+
+	mat_free(e);
+	mat_free(f);
+	mat_free(g);
+	mat_free(h);
+
+	mat_free(r);
+	mat_free(s);
+	mat_free(t);
+	mat_free(u);
+}
+
 void mat_fill(mat_t* m)
 {
 	assert(m && m->b);
@@ -189,16 +321,20 @@ int main(int argc, char* argv[])
 
 	mat_t* m0 = mat_alloc(0, 0, n, NULL);
 	mat_t* m1 = mat_alloc(0, 0, n, NULL);
-	mat_t* m = mat_alloc(0, 0, n, NULL);
+	mat_t* m2 = mat_alloc(0, 0, n, NULL);
+	mat_t* m3 = mat_alloc(0, 0, n, NULL);
 
 	mat_fill(m0);
 	mat_fill(m1);
 
-	mat_mul(m, m0, m1);
-
+	mat_mul(m2, m0, m1);
+	mat_mul_strassen(m3, m0, m1);
+#if 1
 	mat_print(m0);
 	mat_print(m1);
-	mat_print(m);
+	mat_print(m2);
+	mat_print(m3);
+#endif
 
 	return 0;
 }