mat_trans
authoryu.dongliang <18588496441@163.com>
Sun, 1 Sep 2024 18:33:07 +0000 (02:33 +0800)
committeryu.dongliang <18588496441@163.com>
Sun, 1 Sep 2024 18:33:07 +0000 (02:33 +0800)
mat.c

diff --git a/mat.c b/mat.c
index c873b4d76cf2568760cb755ca62bb3528c3d4c8c..91d572a7373317c7a5d454125190c56f11be0ffa 100644 (file)
--- a/mat.c
+++ b/mat.c
@@ -122,6 +122,25 @@ void mat_sub(mat_t* m, mat_t* m0, mat_t* m1)
        }
 }
 
+void mat_trans(mat_t* t, mat_t* m)
+{
+       assert(m->n == t->n);
+
+       int i;
+       int j;
+       int n = m->n;
+
+       for (i = 0; i < n; i++) {
+               for (j = 0; j < n; j++) {
+
+                       int ij = (i + m->i) * m->b->n + (j + m->j);
+                       int ji = (j + t->i) * t->b->n + (i + t->j);
+
+                       t->b->data[ji] = m->b->data[ij];
+               }
+       }
+}
+
 void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
 {
        assert(m0->n == m1->n);
@@ -132,16 +151,53 @@ void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
        int k;
        int n = m->n;
 
+       mat_t* t1 = mat_alloc(0, 0, n, NULL);
+
+       mat_trans(t1, m1);
+
        for (i = 0; i < n; i++) {
                for (j = 0; j < n; j++) {
 
+                       int* d0 = m0->b->data + (i + m0->i) * m0->b->n + m0->j;
+                       int* d1 = t1->b->data + (j + t1->i) * t1->b->n + t1->j;
+
                        int sum = 0;
                        for (k  = 0; k < n; k++) {
+                               sum += d0[k] * d1[k];
+                       }
+
+                       int ij  = (i + m->i) * m->b->n + (j + m->j);
+
+                       m->b->data[ij] = sum;
+               }
+       }
 
-                               int ik = (i + m0->i) * m0->b->n + (k + m0->j);
-                               int kj = (k + m1->i) * m1->b->n + (j + m1->j);
+       mat_free(t1);
+}
+
+void mat_mul2(mat_t* m, mat_t* m0, mat_t* m1)
+{
+       assert(m0->n == m1->n);
+       assert(m->n  == m0->n);
+
+       int i;
+       int j;
+       int k;
+       int n = m->n;
+
+       mat_t* t1 = mat_alloc(0, 0, n, NULL);
 
-                               sum += m0->b->data[ik] * m1->b->data[kj];
+       mat_trans(t1, m1);
+
+       for (i = 0; i < n; i++) {
+               for (j = 0; j < n; j++) {
+
+                       int* d0 = m0->b->data + (i + m0->i) * m0->b->n + m0->j;
+                       int* d1 = t1->b->data + (j + t1->i) * t1->b->n + t1->j;
+
+                       int sum = 0;
+                       for (k  = 0; k < n; k++) {
+                               sum += d0[k] * d1[k];
                        }
 
                        int ij  = (i + m->i) * m->b->n + (j + m->j);
@@ -149,6 +205,8 @@ void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
                        m->b->data[ij] = sum;
                }
        }
+
+       mat_free(t1);
 }
 
 void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1, int n_min)
@@ -157,7 +215,7 @@ void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1, int n_min)
        assert(m->n  == m0->n);
 #if 1
        if (m->n <= n_min)
-               return mat_mul(m, m0, m1);
+               return mat_mul2(m, m0, m1);
 #endif
        if (n_min < 16)
                printf("%s(), m->n: %d, n_min: %d\n", __func__, m->n, n_min);
@@ -355,9 +413,14 @@ int main(int argc, char* argv[])
                        mat_print(m2);
                        mat_print(m3);
                        break;
+               default:
+                       printf("trans: \n");
+                       mat_trans(m1, m0);
+                       mat_print(m0);
+                       mat_print(m1);
+                       break;
        };
 
        printf("%s(), g_buf_size_max: %d\n", __func__, g_buf_size_max);
        return 0;
 }
-