From: yu.dongliang Date: Fri, 20 Nov 2020 05:09:44 +0000 (+0800) Subject: ok X-Git-Url: http://baseworks.info/?a=commitdiff_plain;h=e0ebbfe8e96cff1521238f1908cb7f6ce56ddcfc;p=mat.git ok --- diff --git a/Makefile b/Makefile index 92222ef..cf0669e 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,2 @@ all: - gcc -g -O3 mat.c + gcc -g -O3 mat.c -o mat_mul diff --git a/mat.c b/mat.c index 2b94027..8eacbfa 100644 --- a/mat.c +++ b/mat.c @@ -300,12 +300,25 @@ void mat_print(mat_t* m) int main(int argc, char* argv[]) { - if (argc < 2) { - printf("argc: %d, < 2\n", argc); + if (argc < 3) { + printf("./mat_mul n flag:\n"); + printf("n: nxn mat, n = 2^N, N > 0\n"); + printf("flag: 0 (normal), 1 (strassen), 2 (all & print)\n"); return -1; } - int n = atoi(argv[1]); + int n = atoi(argv[1]); + int flag = atoi(argv[2]); + + if (n < 2) { + printf("n must >= 2, n: %d\n", n); + return -1; + } + + if (n & (n - 1)) { + printf("n: %d, not 2^N\n", n); + return -1; + } srand(time(NULL)); @@ -317,14 +330,24 @@ int main(int argc, char* argv[]) mat_fill(m0); mat_fill(m1); - mat_mul(m2, m0, m1); - mat_mul_strassen(m3, m0, m1); -#if 1 - mat_print(m0); - mat_print(m1); - mat_print(m2); - mat_print(m3); -#endif + switch (flag) { + case 0: + mat_mul(m2, m0, m1); + break; + case 1: + mat_mul_strassen(m3, m0, m1); + break; + case 2: + mat_mul(m2, m0, m1); + mat_mul_strassen(m3, m0, m1); + + mat_print(m0); + mat_print(m1); + mat_print(m2); + mat_print(m3); + break; + }; + return 0; }