From: yu.dongliang <maja_creater@qq.com>
Date: Fri, 20 Nov 2020 05:34:33 +0000 (+0800)
Subject: ok
X-Git-Url: http://baseworks.info/?a=commitdiff_plain;h=4134d31070092100860fe58e3af40d9d103003bf;p=mat.git

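mat_mul_strassen: recurse on sub-blocks with an n_min cutoff

Add an n_min parameter to mat_mul_strassen(). When m->n <= n_min the
function falls back to mat_mul(); otherwise the seven products p1..p7
are now computed by recursive mat_mul_strassen() calls instead of
mat_mul().

Also drop the "#if 1" guards around the product groups, replace the
unconditional size printf with one that fires only when n_min < 16,
and update the callers in main() to pass 64 (case 1) and 1 (case 2).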
---

diff --git a/mat.c b/mat.c
index 8eacbfa..39784b1 100644
--- a/mat.c
+++ b/mat.c
@@ -139,14 +139,18 @@ void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
 	}
 }
 
-void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
+void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1, int n_min)
 {
 	assert(m0->n == m1->n);
 	assert(m->n  == m0->n);
 
-	int n = m->n / 2;
+	if (m->n <= n_min)
+		return mat_mul(m, m0, m1);
+
+	if (n_min < 16)
+		printf("%s(), n_min: %d\n", __func__, n_min);
 
-	printf("m->n: %d, n: %d\n", m->n, n);
+	int n = m->n / 2;
 
 	mat_t* a = mat_alloc(0, 0, n, m0->b);
 	mat_t* b = mat_alloc(0, n, n, m0->b);
@@ -173,63 +177,59 @@ void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1)
 
 	// tmp mat t0
 	mat_t* t0 = mat_alloc(0, 0, n, NULL);
-#if 1
+
 	// p1 = a * (f - h)
 	mat_sub(t0, f, h);
-	mat_mul(p1, a, t0);
+	mat_mul_strassen(p1, a, t0, n_min);
 
 	// p2 = (a + b) * h
 	mat_add(t0, a,  b);
-	mat_mul(p2, t0, h);
+	mat_mul_strassen(p2, t0, h, n_min);
 
 	// s = p1 + p2
 	mat_add(s, p1,  p2);
-#endif
 
-#if 1
+
 	// p3 = (c + d) * e
 	mat_add(t0, c,  d);
-	mat_mul(p3, t0, e);
+	mat_mul_strassen(p3, t0, e, n_min);
 
 	// p4 = d * (g - e)
 	mat_sub(t0, g,  e);
-	mat_mul(p4, d,  t0);
+	mat_mul_strassen(p4, d,  t0, n_min);
 
 	// t = p3 + p4
 	mat_add(t, p3,  p4);
-#endif
 
-#if 1
+
 	// tmp mat t1
 	mat_t* t1 = mat_alloc(0, 0, n, NULL);
 
 	//p5 = (a + d) * (e + h)
 	mat_add(t0, a,  d);
 	mat_add(t1, e,  h);
-	mat_mul(p5, t0, t1);
+	mat_mul_strassen(p5, t0, t1, n_min);
 
 	//p6 = (b - d) * (g + h)
 	mat_sub(t0, b,  d);
 	mat_add(t1, g,  h);
-	mat_mul(p6, t0, t1);
+	mat_mul_strassen(p6, t0, t1, n_min);
 
 	// r = p5 + p4 - p2 + p6
 	mat_add(r, p5,  p4);
 	mat_sub(r, r,   p2);
 	mat_add(r, r,   p6);
-#endif
 
-#if 1
+
 	//p7 = (a - c) * (e + f)
 	mat_sub(t0, a,  c);
 	mat_add(t1, e,  f);
-	mat_mul(p7, t0, t1);
+	mat_mul_strassen(p7, t0, t1, n_min);
 
 	// u = p5 + p1 -p3 -p7
 	mat_add(u, p5,  p1);
 	mat_sub(u, u,   p3);
 	mat_sub(u, u,   p7);
-#endif
 
 	// free unused mats
 	mat_free(t0);
@@ -335,11 +335,11 @@ int main(int argc, char* argv[])
 			mat_mul(m2, m0, m1);
 			break;
 		case 1:
-			mat_mul_strassen(m3, m0, m1);
+			mat_mul_strassen(m3, m0, m1, 64);
 			break;
 		case 2:
 			mat_mul(m2, m0, m1);
-			mat_mul_strassen(m3, m0, m1);
+			mat_mul_strassen(m3, m0, m1, 1);
 
 			mat_print(m0);
 			mat_print(m1);