mat_free(t1);
}
-void mat_mul2(mat_t* m, mat_t* m0, mat_t* m1)
-{
- assert(m0->n == m1->n);
- assert(m->n == m0->n);
-
- int i;
- int j;
- int k;
- int n = m->n;
-
- mat_t* t1 = mat_alloc(0, 0, n, NULL);
-
- mat_trans(t1, m1);
-
- for (i = 0; i < n; i++) {
- for (j = 0; j < n; j++) {
-
- int* d0 = m0->b->data + (i + m0->i) * m0->b->n + m0->j;
- int* d1 = t1->b->data + (j + t1->i) * t1->b->n + t1->j;
-
- int sum = 0;
- for (k = 0; k < n; k++) {
- sum += d0[k] * d1[k];
- }
-
- int ij = (i + m->i) * m->b->n + (j + m->j);
-
- m->b->data[ij] = sum;
- }
- }
-
- mat_free(t1);
-}
-
void mat_mul_strassen(mat_t* m, mat_t* m0, mat_t* m1, int n_min)
{
assert(m0->n == m1->n);
assert(m->n == m0->n);
#if 1
if (m->n <= n_min)
- return mat_mul2(m, m0, m1);
+ return mat_mul(m, m0, m1);
#endif
if (n_min < 16)
printf("%s(), m->n: %d, n_min: %d\n", __func__, m->n, n_min);