}
}
+void mat_trans(mat_t* t, mat_t* m)
+{
+ assert(m->n == t->n);
+
+ int i;
+ int j;
+ int n = m->n;
+
+ for (i = 0; i < n; i++) {
+ for (j = 0; j < n; j++) {
+
+ int ij = (i + m->i) * m->b->n + (j + m->j);
+ int ji = (j + t->i) * t->b->n + (i + t->j);
+
+ t->b->data[ji] = m->b->data[ij];
+ }
+ }
+}
+
void mat_mul(mat_t* m, mat_t* m0, mat_t* m1)
{
assert(m0->n == m1->n);
int k;
int n = m->n;
+ mat_t* t1 = mat_alloc(0, 0, n, NULL);
+
+ mat_trans(t1, m1);
+
for (i = 0; i < n; i++) {
for (j = 0; j < n; j++) {
+ int* d0 = m0->b->data + (i + m0->i) * m0->b->n + m0->j;
+ int* d1 = t1->b->data + (j + t1->i) * t1->b->n + t1->j;
+
int sum = 0;
for (k = 0; k < n; k++) {
+ sum += d0[k] * d1[k];
+ }
+
+ int ij = (i + m->i) * m->b->n + (j + m->j);
+
+ m->b->data[ij] = sum;
+ }
+ }
- int ik = (i + m0->i) * m0->b->n + (k + m0->j);
- int kj = (k + m1->i) * m1->b->n + (j + m1->j);
+ mat_free(t1);
+}
+
+void mat_mul2(mat_t* m, mat_t* m0, mat_t* m1)
+{
+ assert(m0->n == m1->n);
+ assert(m->n == m0->n);
+
+ int i;
+ int j;
+ int k;
+ int n = m->n;
+
+ mat_t* t1 = mat_alloc(0, 0, n, NULL);
- sum += m0->b->data[ik] * m1->b->data[kj];
+ mat_trans(t1, m1);
+
+ for (i = 0; i < n; i++) {
+ for (j = 0; j < n; j++) {
+
+ int* d0 = m0->b->data + (i + m0->i) * m0->b->n + m0->j;
+ int* d1 = t1->b->data + (j + t1->i) * t1->b->n + t1->j;
+
+ int sum = 0;
+ for (k = 0; k < n; k++) {
+ sum += d0[k] * d1[k];
}
int ij = (i + m->i) * m->b->n + (j + m->j);
m->b->data[ij] = sum;
}
}
+
+ mat_free(t1);
}
void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1, int n_min)
assert(m->n == m0->n);
#if 1
if (m->n <= n_min)
- return mat_mul(m, m0, m1);
+ return mat_mul2(m, m0, m1);
#endif
if (n_min < 16)
printf("%s(), m->n: %d, n_min: %d\n", __func__, m->n, n_min);
mat_print(m2);
mat_print(m3);
break;
+ default:
+ printf("trans: \n");
+ mat_trans(m1, m0);
+ mat_print(m0);
+ mat_print(m1);
+ break;
};
printf("%s(), g_buf_size_max: %d\n", __func__, g_buf_size_max);
return 0;
}
-