From: yu.dongliang Date: Fri, 20 Nov 2020 05:34:33 +0000 (+0800) Subject: ok X-Git-Url: http://baseworks.info/?a=commitdiff_plain;h=4134d31070092100860fe58e3af40d9d103003bf;p=mat.git ok --- diff --git a/mat.c b/mat.c index 8eacbfa..39784b1 100644 --- a/mat.c +++ b/mat.c @@ -139,14 +139,18 @@ void mat_mul(mat_t* m, mat_t* m0, mat_t* m1) } } -void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1) +void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1, int n_min) { assert(m0->n == m1->n); assert(m->n == m0->n); - int n = m->n / 2; + if (m->n <= n_min) + return mat_mul(m, m0, m1); + + if (n_min < 16) + printf("%s(), n_min: %d\n", __func__, n_min); - printf("m->n: %d, n: %d\n", m->n, n); + int n = m->n / 2; mat_t* a = mat_alloc(0, 0, n, m0->b); mat_t* b = mat_alloc(0, n, n, m0->b); @@ -173,63 +177,59 @@ void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1) // tmp mat t0 mat_t* t0 = mat_alloc(0, 0, n, NULL); -#if 1 + // p1 = a * (f - h) mat_sub(t0, f, h); - mat_mul(p1, a, t0); + mat_mul_strassen(p1, a, t0, n_min); // p2 = (a + b) * h mat_add(t0, a, b); - mat_mul(p2, t0, h); + mat_mul_strassen(p2, t0, h, n_min); // s = p1 + p2 mat_add(s, p1, p2); -#endif -#if 1 + // p3 = (c + d) * e mat_add(t0, c, d); - mat_mul(p3, t0, e); + mat_mul_strassen(p3, t0, e, n_min); // p4 = d * (g - e) mat_sub(t0, g, e); - mat_mul(p4, d, t0); + mat_mul_strassen(p4, d, t0, n_min); // t = p3 + p4 mat_add(t, p3, p4); -#endif -#if 1 + // tmp mat t1 mat_t* t1 = mat_alloc(0, 0, n, NULL); //p5 = (a + d) * (e + h) mat_add(t0, a, d); mat_add(t1, e, h); - mat_mul(p5, t0, t1); + mat_mul_strassen(p5, t0, t1, n_min); //p6 = (b - d) * (g + h) mat_sub(t0, b, d); mat_add(t1, g, h); - mat_mul(p6, t0, t1); + mat_mul_strassen(p6, t0, t1, n_min); // r = p5 + p4 - p2 + p6 mat_add(r, p5, p4); mat_sub(r, r, p2); mat_add(r, r, p6); -#endif -#if 1 + //p7 = (a - c) * (e + f) mat_sub(t0, a, c); mat_add(t1, e, f); - mat_mul(p7, t0, t1); + mat_mul_strassen(p7, t0, t1, n_min); // u = p5 + p1 -p3 -p7 mat_add(u, p5, p1); mat_sub(u, u, p3); mat_sub(u, u, p7); -#endif // free unused mats mat_free(t0); @@ -335,11 +335,11 @@ int main(int argc, char* argv[]) mat_mul(m2, m0, m1); break; case 1: - mat_mul_strassen(m3, m0, m1); + mat_mul_strassen(m3, m0, m1, 64); break; case 2: mat_mul(m2, m0, m1); - mat_mul_strassen(m3, m0, m1); + mat_mul_strassen(m3, m0, m1, 1); mat_print(m0); mat_print(m1);