int main(int argc, char* argv[])
 {
-       if (argc < 2) {
-               printf("argc: %d, < 2\n", argc);
+       if (argc < 3) {
+               printf("./mat_mul n flag:\n");
+               printf("n: nxn mat, n = 2^N, N > 0\n");
+               printf("flag: 0 (normal), 1 (strassen), 2 (all & print)\n");
                return -1;
        }
 
-       int n = atoi(argv[1]);
+       int n    = atoi(argv[1]);
+       int flag = atoi(argv[2]);
+
+       if (n < 2) {
+               printf("n must >= 2, n: %d\n", n);
+               return -1;
+       }
+
+       if (n & (n - 1)) {
+               printf("n: %d, not 2^N\n", n);
+               return -1;
+       }
 
        srand(time(NULL));
 
        mat_fill(m0);
        mat_fill(m1);
 
-       mat_mul(m2, m0, m1);
-       mat_mul_strassen(m3, m0, m1);
-#if 1
-       mat_print(m0);
-       mat_print(m1);
-       mat_print(m2);
-       mat_print(m3);
-#endif
+       switch (flag) {
+               case 0:
+                       mat_mul(m2, m0, m1);
+                       break;
+               case 1:
+                       mat_mul_strassen(m3, m0, m1);
+                       break;
+               case 2:
+                       mat_mul(m2, m0, m1);
+                       mat_mul_strassen(m3, m0, m1);
+
+                       mat_print(m0);
+                       mat_print(m1);
+                       mat_print(m2);
+                       mat_print(m3);
+                       break;
+       };
+
        return 0;
 }