Revert "trans"
authoryu.dongliang <18588496441@163.com>
Sun, 1 Sep 2024 18:04:51 +0000 (02:04 +0800)
committeryu.dongliang <18588496441@163.com>
Sun, 1 Sep 2024 18:04:51 +0000 (02:04 +0800)
This reverts commit 8a82b4fe351dd8a8df401af62883152d202813fa.

mat.c

diff --git a/mat.c b/mat.c
index fde7bb39cbbbf135249fbec984942170a1463cfa..c873b4d76cf2568760cb755ca62bb3528c3d4c8c 100644 (file)
--- a/mat.c
+++ b/mat.c
@@ -101,25 +101,6 @@ void mat_add(mat_t* m, mat_t* m0, mat_t* m1)
        }
 }
 
-void mat_trans(mat_t* m)
-{
-       int i;
-       int j;
-       int n = m->n;
-
-       for (i = 0; i < n; i++) {
-               for (j = i + 1; j < n; j++) {
-
-                       int ij = (i + m->i) * m->b->n + (j + m->j);
-                       int ji = (j + m->i) * m->b->n + (i + m->j);
-
-                       int tmp = m->b->data[ij];
-                       m->b->data[ij] = m->b->data[ji];
-                       m->b->data[ji] = tmp;
-               }
-       }
-}
-
 void mat_sub(mat_t* m, mat_t* m0, mat_t* m1)
 {
        assert(m0->n == m1->n);
@@ -152,15 +133,16 @@ void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
        int n = m->n;
 
        for (i = 0; i < n; i++) {
-
                for (j = 0; j < n; j++) {
 
-                       int* d0 = m0->b->data +  (i + m0->i) * m0->b->n + m0->j;
-                       int* d1 = m1->b->data +  (i + m1->i) * m1->b->n + m1->j;
-
                        int sum = 0;
-                       for (k  = 0; k < n; k++)
-                               sum += d0[k] * d1[k];
+                       for (k  = 0; k < n; k++) {
+
+                               int ik = (i + m0->i) * m0->b->n + (k + m0->j);
+                               int kj = (k + m1->i) * m1->b->n + (j + m1->j);
+
+                               sum += m0->b->data[ik] * m1->b->data[kj];
+                       }
 
                        int ij  = (i + m->i) * m->b->n + (j + m->j);
 
@@ -328,7 +310,7 @@ int main(int argc, char* argv[])
        if (argc < 3) {
                printf("./mat_mul n flag:\n");
                printf("n: nxn mat, n = 2^N, N > 0\n");
-               printf("flag: 0 (normal), 1 (strassen), 2 (all & print), 3 (trans)\n");
+               printf("flag: 0 (normal), 1 (strassen), 2 (all & print)\n");
                return -1;
        }
 
@@ -372,14 +354,10 @@ int main(int argc, char* argv[])
                        mat_print(m1);
                        mat_print(m2);
                        mat_print(m3);
-               case 3:
-                       mat_print(m0);
-                       mat_trans(m0);
-                       printf("trans: \n");
-                       mat_print(m0);
                        break;
        };
 
        printf("%s(), g_buf_size_max: %d\n", __func__, g_buf_size_max);
        return 0;
 }
+