ok
authoryu.dongliang <maja_creater@qq.com>
Fri, 20 Nov 2020 05:09:44 +0000 (13:09 +0800)
committeryu.dongliang <maja_creater@qq.com>
Fri, 20 Nov 2020 05:09:44 +0000 (13:09 +0800)
Makefile
mat.c

index 92222ef4fa266f8acb57ba344d552239656c6f02..cf0669e9c008e9205dc4a98de6cd22dc4cd113b9 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1,2 +1,2 @@
 all:
-       gcc -g -O3 mat.c
+       gcc -g -O3 mat.c -o mat_mul
diff --git a/mat.c b/mat.c
index 2b94027718a346b7a247ec219f66e32e3b75130a..8eacbfad814ba5a31879be44cc5f48037262cfad 100644 (file)
--- a/mat.c
+++ b/mat.c
@@ -300,12 +300,25 @@ void mat_print(mat_t* m)
 
 int main(int argc, char* argv[])
 {
-       if (argc < 2) {
-               printf("argc: %d, < 2\n", argc);
+       if (argc < 3) {
+               printf("./mat_mul n flag:\n");
+               printf("n: nxn mat, n = 2^N, N > 0\n");
+               printf("flag: 0 (normal), 1 (strassen), 2 (all & print)\n");
                return -1;
        }
 
-       int n = atoi(argv[1]);
+       int n    = atoi(argv[1]);
+       int flag = atoi(argv[2]);
+
+       if (n < 2) {
+               printf("n must >= 2, n: %d\n", n);
+               return -1;
+       }
+
+       if (n & (n - 1)) {
+               printf("n: %d, not 2^N\n", n);
+               return -1;
+       }
 
        srand(time(NULL));
 
@@ -317,14 +330,24 @@ int main(int argc, char* argv[])
        mat_fill(m0);
        mat_fill(m1);
 
-       mat_mul(m2, m0, m1);
-       mat_mul_strassen(m3, m0, m1);
-#if 1
-       mat_print(m0);
-       mat_print(m1);
-       mat_print(m2);
-       mat_print(m3);
-#endif
+       switch (flag) {
+               case 0:
+                       mat_mul(m2, m0, m1);
+                       break;
+               case 1:
+                       mat_mul_strassen(m3, m0, m1);
+                       break;
+               case 2:
+                       mat_mul(m2, m0, m1);
+                       mat_mul_strassen(m3, m0, m1);
+
+                       mat_print(m0);
+                       mat_print(m1);
+                       mat_print(m2);
+                       mat_print(m3);
+                       break;
+       };
+
        return 0;
 }