From 674015ea67412e5777fe98b421b79502f19046aa Mon Sep 17 00:00:00 2001
From: "yu.dongliang" <18588496441@163.com>
Date: Mon, 2 Sep 2024 02:33:07 +0800
Subject: [PATCH] mat_trans

---
Review notes (text between the "---" line and the diffstat is not part of
the commit message).

NOTE(review): the line breaks, indentation and blank context lines of this
patch were restored from a whitespace-collapsed copy. The hunk line counts
add up (68 insertions, 5 deletions), but the use of tabs and the position of
the single blank context line in the main() hunk are assumptions; check them
against mat.c before applying.

What the patch changes:

* mat_trans(t, m) [new]: asserts m->n == t->n, then for every i, j in
  [0, n) copies m->b->data[(i + m->i) * m->b->n + (j + m->j)] to
  t->b->data[(j + t->i) * t->b->n + (i + t->j)], i.e. the element read at
  row i, column j of m is stored at row j, column i of t. No temporary is
  used, so if t and m cover the same elements of one buffer, elements are
  overwritten before they are read.

* mat_mul(m, m0, m1): now obtains t1 from mat_alloc(0, 0, n, NULL), fills it
  with mat_trans(t1, m1), and computes each output element as the sum of
  d0[k] * d1[k], where d0 points at row i of m0 and d1 at row j of t1, so
  both operands are read with consecutive k. t1 is released with mat_free()
  before returning. The result of mat_alloc() is used without a check;
  presumably mat_alloc() cannot return NULL here -- TODO confirm.

* mat_mul2(m, m0, m1) [new]: the added body consists of the same statements
  as the parts of mat_mul() visible in this patch.

* mat_mul_strassen(): when m->n <= n_min it now returns through mat_mul2()
  instead of mat_mul().

* main(): the switch gains a default case that prints "trans: ", calls
  mat_trans(m1, m0) and prints m0 and m1 with mat_print(). The blank line
  after the closing brace at the end of the file is removed.

 mat.c | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 68 insertions(+), 5 deletions(-)

diff --git a/mat.c b/mat.c
index c873b4d..91d572a 100644
--- a/mat.c
+++ b/mat.c
@@ -122,6 +122,25 @@ void mat_sub(mat_t* m, mat_t* m0, mat_t* m1)
 	}
 }
 
+void mat_trans(mat_t* t, mat_t* m)
+{
+	assert(m->n == t->n);
+
+	int i;
+	int j;
+	int n = m->n;
+
+	for (i = 0; i < n; i++) {
+		for (j = 0; j < n; j++) {
+
+			int ij = (i + m->i) * m->b->n + (j + m->j);
+			int ji = (j + t->i) * t->b->n + (i + t->j);
+
+			t->b->data[ji] = m->b->data[ij];
+		}
+	}
+}
+
 void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
 {
 	assert(m0->n == m1->n);
@@ -132,16 +151,53 @@ void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
 	int k;
 	int n = m->n;
 
+	mat_t* t1 = mat_alloc(0, 0, n, NULL);
+
+	mat_trans(t1, m1);
+
 	for (i = 0; i < n; i++) {
 		for (j = 0; j < n; j++) {
 
+			int* d0 = m0->b->data + (i + m0->i) * m0->b->n + m0->j;
+			int* d1 = t1->b->data + (j + t1->i) * t1->b->n + t1->j;
+
 			int sum = 0;
 			for (k = 0; k < n; k++) {
+				sum += d0[k] * d1[k];
+			}
+
+			int ij = (i + m->i) * m->b->n + (j + m->j);
+
+			m->b->data[ij] = sum;
+		}
+	}
 
-				int ik = (i + m0->i) * m0->b->n + (k + m0->j);
-				int kj = (k + m1->i) * m1->b->n + (j + m1->j);
+	mat_free(t1);
+}
+
+void mat_mul2(mat_t* m, mat_t* m0, mat_t* m1)
+{
+	assert(m0->n == m1->n);
+	assert(m->n == m0->n);
+
+	int i;
+	int j;
+	int k;
+	int n = m->n;
+
+	mat_t* t1 = mat_alloc(0, 0, n, NULL);
 
-				sum += m0->b->data[ik] * m1->b->data[kj];
+	mat_trans(t1, m1);
+
+	for (i = 0; i < n; i++) {
+		for (j = 0; j < n; j++) {
+
+			int* d0 = m0->b->data + (i + m0->i) * m0->b->n + m0->j;
+			int* d1 = t1->b->data + (j + t1->i) * t1->b->n + t1->j;
+
+			int sum = 0;
+			for (k = 0; k < n; k++) {
+				sum += d0[k] * d1[k];
 			}
 
 			int ij = (i + m->i) * m->b->n + (j + m->j);
@@ -149,6 +205,8 @@ void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
 			m->b->data[ij] = sum;
 		}
 	}
+
+	mat_free(t1);
 }
 
 void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1, int n_min)
@@ -157,7 +215,7 @@ void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1, int n_min)
 	assert(m->n == m0->n);
 #if 1
 	if (m->n <= n_min)
-		return mat_mul(m, m0, m1);
+		return mat_mul2(m, m0, m1);
 #endif
 	if (n_min < 16)
 		printf("%s(), m->n: %d, n_min: %d\n", __func__, m->n, n_min);
@@ -355,9 +413,14 @@ int main(int argc, char* argv[])
 		mat_print(m2);
 		mat_print(m3);
 		break;
+	default:
+		printf("trans: \n");
+		mat_trans(m1, m0);
+		mat_print(m0);
+		mat_print(m1);
+		break;
 	};
 
 	printf("%s(), g_buf_size_max: %d\n", __func__, g_buf_size_max);
 	return 0;
 }
-
-- 
2.25.1