int main(int argc, char* argv[])
{
- if (argc < 2) {
- printf("argc: %d, < 2\n", argc);
+ if (argc < 3) {
+ printf("./mat_mul n flag:\n");
+ printf("n: nxn mat, n = 2^N, N > 0\n");
+ printf("flag: 0 (normal), 1 (strassen), 2 (all & print)\n");
return -1;
}
- int n = atoi(argv[1]);
+ int n = atoi(argv[1]);
+ int flag = atoi(argv[2]);
+
+ if (n < 2) {
+ printf("n must >= 2, n: %d\n", n);
+ return -1;
+ }
+
+ if (n & (n - 1)) {
+ printf("n: %d, not 2^N\n", n);
+ return -1;
+ }
srand(time(NULL));
mat_fill(m0);
mat_fill(m1);
- mat_mul(m2, m0, m1);
- mat_mul_strassen(m3, m0, m1);
-#if 1
- mat_print(m0);
- mat_print(m1);
- mat_print(m2);
- mat_print(m3);
-#endif
+ switch (flag) {
+ case 0:
+ mat_mul(m2, m0, m1);
+ break;
+ case 1:
+ mat_mul_strassen(m3, m0, m1);
+ break;
+ case 2:
+ mat_mul(m2, m0, m1);
+ mat_mul_strassen(m3, m0, m1);
+
+ mat_print(m0);
+ mat_print(m1);
+ mat_print(m2);
+ mat_print(m3);
+ break;
+ };
+
return 0;
}