Skip to content

Commit

Permalink
Use symmetry when squaring in gr_mat_mul_strassen (flintlib#2096)
Browse files Browse the repository at this point in the history
* use symmetry when squaring in gr_mat_mul_strassen

* rm duplicated include
  • Loading branch information
fredrik-johansson authored Nov 8, 2024
1 parent b197a2e commit ed0f2b0
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 111 deletions.
322 changes: 215 additions & 107 deletions src/gr_mat/mul_strassen.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@

#include "gr_mat.h"

#include "fmpz_mat.h"

/* todo: optimize for small matrices */
/* todo: bodrato squaring */
/* todo: use fused add-mul operations when supported by
the matrix interface in the future */
/* todo: when squaring, pretransform A12, A21, X2 which are
used twice in the recursive multiplications */

/* The implemented sequence is not Strassen's nor Winograd's, but the sequence
proposed by Bodrato, which is equivalent to Winograd's, and can be easily
Expand All @@ -27,14 +26,8 @@
int gr_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
{
slong ar, ac, br, bc;
slong anr, anc, bnr, bnc;
int status = GR_SUCCESS;

gr_mat_t A11, A12, A21, A22;
gr_mat_t B11, B12, B21, B22;
gr_mat_t C11, C12, C21, C22;
gr_mat_t X1, X2;

ar = A->r;
ac = A->c;
br = B->r;
Expand All @@ -60,132 +53,247 @@ int gr_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t
return status;
}

anr = ar / 2;
anc = ac / 2;
bnr = anc;
bnc = bc / 2;
if (A == B)
{
slong anr;

gr_mat_window_init(A11, A, 0, 0, anr, anc, ctx);
gr_mat_window_init(A12, A, 0, anc, anr, 2 * anc, ctx);
gr_mat_window_init(A21, A, anr, 0, 2 * anr, anc, ctx);
gr_mat_window_init(A22, A, anr, anc, 2 * anr, 2 * anc, ctx);
gr_mat_t A11, A12, A21, A22;
gr_mat_t C11, C12, C21, C22;
gr_mat_t X1, X2;

gr_mat_window_init(B11, B, 0, 0, bnr, bnc, ctx);
gr_mat_window_init(B12, B, 0, bnc, bnr, 2 * bnc, ctx);
gr_mat_window_init(B21, B, bnr, 0, 2 * bnr, bnc, ctx);
gr_mat_window_init(B22, B, bnr, bnc, 2 * bnr, 2 * bnc, ctx);
anr = ar / 2;

gr_mat_window_init(C11, C, 0, 0, anr, bnc, ctx);
gr_mat_window_init(C12, C, 0, bnc, anr, 2 * bnc, ctx);
gr_mat_window_init(C21, C, anr, 0, 2 * anr, bnc, ctx);
gr_mat_window_init(C22, C, anr, bnc, 2 * anr, 2 * bnc, ctx);
gr_mat_window_init(A11, A, 0, 0, anr, anr, ctx);
gr_mat_window_init(A12, A, 0, anr, anr, 2 * anr, ctx);
gr_mat_window_init(A21, A, anr, 0, 2 * anr, anr, ctx);
gr_mat_window_init(A22, A, anr, anr, 2 * anr, 2 * anr, ctx);

gr_mat_init(X1, anr, FLINT_MAX(bnc, anc), ctx);
gr_mat_init(X2, anc, bnc, ctx);
gr_mat_window_init(C11, C, 0, 0, anr, anr, ctx);
gr_mat_window_init(C12, C, 0, anr, anr, 2 * anr, ctx);
gr_mat_window_init(C21, C, anr, 0, 2 * anr, anr, ctx);
gr_mat_window_init(C22, C, anr, anr, 2 * anr, 2 * anr, ctx);

X1->c = anc;
gr_mat_init(X2, anr, anr, ctx);

status |= gr_mat_add(X1, A22, A12, ctx);
status |= gr_mat_add(X2, B22, B12, ctx);
status |= gr_mat_mul(C21, X1, X2, ctx);
status |= gr_mat_add(X2, A22, A12, ctx);
status |= gr_mat_mul(C21, X2, X2, ctx);
status |= gr_mat_sub(X2, A22, A21, ctx);
status |= gr_mat_mul(C22, X2, X2, ctx);
status |= gr_mat_add(X2, X2, A12, ctx);
status |= gr_mat_mul(C11, X2, X2, ctx);

status |= gr_mat_sub(X1, A22, A21, ctx);
status |= gr_mat_sub(X2, B22, B21, ctx);
status |= gr_mat_mul(C22, X1, X2, ctx);
status |= gr_mat_sub(X2, X2, A11, ctx);
status |= gr_mat_mul(C12, X2, A12, ctx);

status |= gr_mat_add(X1, X1, A12, ctx);
status |= gr_mat_add(X2, X2, B12, ctx);
status |= gr_mat_mul(C11, X1, X2, ctx);
gr_mat_init(X1, anr, anr, ctx);

status |= gr_mat_sub(X1, X1, A11, ctx);
status |= gr_mat_mul(C12, X1, B12, ctx);
status |= gr_mat_mul(X1, A12, A21, ctx);
status |= gr_mat_add(C11, C11, X1, ctx);
status |= gr_mat_sub(C12, C11, C12, ctx);
status |= gr_mat_sub(C11, C21, C11, ctx);
status |= gr_mat_mul(C21, A21, X2, ctx);

X1->c = bnc;
status |= gr_mat_mul(X1, A12, B21, ctx);
status |= gr_mat_add(C11, C11, X1, ctx);
status |= gr_mat_add(C12, C12, C22, ctx);
status |= gr_mat_sub(C12, C11, C12, ctx);
status |= gr_mat_sub(C11, C21, C11, ctx);
status |= gr_mat_sub(X2, X2, B11, ctx);
status |= gr_mat_mul(C21, A21, X2, ctx);
gr_mat_clear(X2, ctx);

gr_mat_clear(X2, ctx);
status |= gr_mat_sub(C21, C11, C21, ctx);
status |= gr_mat_sub(C12, C12, C22, ctx);
status |= gr_mat_add(C22, C22, C11, ctx);
status |= gr_mat_mul(C11, A11, A11, ctx);
status |= gr_mat_add(C11, X1, C11, ctx);

status |= gr_mat_sub(C21, C11, C21, ctx);
status |= gr_mat_add(C22, C22, C11, ctx);
status |= gr_mat_mul(C11, A11, B11, ctx);
gr_mat_clear(X1, ctx);

status |= gr_mat_add(C11, X1, C11, ctx);
gr_mat_window_clear(A11, ctx);
gr_mat_window_clear(A12, ctx);
gr_mat_window_clear(A21, ctx);
gr_mat_window_clear(A22, ctx);

X1->c = FLINT_MAX(bnc, anc);
gr_mat_clear(X1, ctx);
gr_mat_window_clear(C11, ctx);
gr_mat_window_clear(C12, ctx);
gr_mat_window_clear(C21, ctx);
gr_mat_window_clear(C22, ctx);

gr_mat_window_clear(A11, ctx);
gr_mat_window_clear(A12, ctx);
gr_mat_window_clear(A21, ctx);
gr_mat_window_clear(A22, ctx);
if (ar > 2 * anr)
{
{
gr_mat_t Ac, Cc;
gr_mat_window_init(Ac, A, 0, 2 * anr, ar, ar, ctx);
gr_mat_window_init(Cc, C, 0, 2 * anr, ar, ar, ctx);

gr_mat_window_clear(B11, ctx);
gr_mat_window_clear(B12, ctx);
gr_mat_window_clear(B21, ctx);
gr_mat_window_clear(B22, ctx);
status |= gr_mat_mul(Cc, A, Ac, ctx);

gr_mat_window_clear(C11, ctx);
gr_mat_window_clear(C12, ctx);
gr_mat_window_clear(C21, ctx);
gr_mat_window_clear(C22, ctx);
gr_mat_window_clear(Ac, ctx);
gr_mat_window_clear(Cc, ctx);
}

if (bc > 2 * bnc)
{
gr_mat_t Bc, Cc;
gr_mat_window_init(Bc, B, 0, 2 * bnc, ac, bc, ctx);
gr_mat_window_init(Cc, C, 0, 2 * bnc, ar, bc, ctx);
status |= gr_mat_mul(Cc, A, Bc, ctx);
gr_mat_window_clear(Bc, ctx);
gr_mat_window_clear(Cc, ctx);
}
{
gr_mat_t Ar, Cr;
gr_mat_t As;

gr_mat_window_init(Ar, A, 2 * anr, 0, ar, ar, ctx);
gr_mat_window_init(Cr, C, 2 * anr, 0, ar, 2 * anr, ctx);
gr_mat_window_init(As, A, 0, 0, ar, 2 * anr, ctx);

status |= gr_mat_mul(Cr, Ar, As, ctx);

gr_mat_window_clear(As, ctx);
gr_mat_window_clear(Ar, ctx);
gr_mat_window_clear(Cr, ctx);
}

{
gr_mat_t Ac, Ar, Cb, tmp;

gr_mat_window_init(Ac, A, 0, 2 * anr, 2 * anr, ar, ctx);
gr_mat_window_init(Ar, A, 2 * anr, 0, ar, 2 * anr, ctx);
gr_mat_window_init(Cb, C, 0, 0, 2 * anr, 2 * anr, ctx);
gr_mat_init(tmp, 2 * anr, 2 * anr, ctx);

status |= gr_mat_mul(tmp, Ac, Ar, ctx);
status |= gr_mat_add(Cb, Cb, tmp, ctx);

if (ar > 2 * anr)
gr_mat_clear(tmp, ctx);
gr_mat_window_clear(Ac, ctx);
gr_mat_window_clear(Ar, ctx);
gr_mat_window_clear(Cb, ctx);
}
}
}
else
{
gr_mat_t Ar, Br, Cr;
gr_mat_window_init(Ar, A, 2 * anr, 0, ar, ac, ctx);
gr_mat_window_init(Cr, C, 2 * anr, 0, ar, 2 * bnc, ctx);
slong anr, anc, bnr, bnc;
gr_mat_t A11, A12, A21, A22;
gr_mat_t B11, B12, B21, B22;
gr_mat_t C11, C12, C21, C22;
gr_mat_t X1, X2;

anr = ar / 2;
anc = ac / 2;
bnr = anc;
bnc = bc / 2;

gr_mat_window_init(A11, A, 0, 0, anr, anc, ctx);
gr_mat_window_init(A12, A, 0, anc, anr, 2 * anc, ctx);
gr_mat_window_init(A21, A, anr, 0, 2 * anr, anc, ctx);
gr_mat_window_init(A22, A, anr, anc, 2 * anr, 2 * anc, ctx);

gr_mat_window_init(B11, B, 0, 0, bnr, bnc, ctx);
gr_mat_window_init(B12, B, 0, bnc, bnr, 2 * bnc, ctx);
gr_mat_window_init(B21, B, bnr, 0, 2 * bnr, bnc, ctx);
gr_mat_window_init(B22, B, bnr, bnc, 2 * bnr, 2 * bnc, ctx);

gr_mat_window_init(C11, C, 0, 0, anr, bnc, ctx);
gr_mat_window_init(C12, C, 0, bnc, anr, 2 * bnc, ctx);
gr_mat_window_init(C21, C, anr, 0, 2 * anr, bnc, ctx);
gr_mat_window_init(C22, C, anr, bnc, 2 * anr, 2 * bnc, ctx);

gr_mat_init(X1, anr, FLINT_MAX(bnc, anc), ctx);
gr_mat_init(X2, anc, bnc, ctx);

X1->c = anc;

status |= gr_mat_add(X1, A22, A12, ctx);
status |= gr_mat_add(X2, B22, B12, ctx);
status |= gr_mat_mul(C21, X1, X2, ctx);

status |= gr_mat_sub(X1, A22, A21, ctx);
status |= gr_mat_sub(X2, B22, B21, ctx);
status |= gr_mat_mul(C22, X1, X2, ctx);

status |= gr_mat_add(X1, X1, A12, ctx);
status |= gr_mat_add(X2, X2, B12, ctx);
status |= gr_mat_mul(C11, X1, X2, ctx);

status |= gr_mat_sub(X1, X1, A11, ctx);
status |= gr_mat_mul(C12, X1, B12, ctx);

X1->c = bnc;
status |= gr_mat_mul(X1, A12, B21, ctx);
status |= gr_mat_add(C11, C11, X1, ctx);
status |= gr_mat_add(C12, C12, C22, ctx);
status |= gr_mat_sub(C12, C11, C12, ctx);
status |= gr_mat_sub(C11, C21, C11, ctx);
status |= gr_mat_sub(X2, X2, B11, ctx);
status |= gr_mat_mul(C21, A21, X2, ctx);

gr_mat_clear(X2, ctx);

status |= gr_mat_sub(C21, C11, C21, ctx);
status |= gr_mat_add(C22, C22, C11, ctx);
status |= gr_mat_mul(C11, A11, B11, ctx);

status |= gr_mat_add(C11, X1, C11, ctx);

X1->c = FLINT_MAX(bnc, anc);
gr_mat_clear(X1, ctx);

gr_mat_window_clear(A11, ctx);
gr_mat_window_clear(A12, ctx);
gr_mat_window_clear(A21, ctx);
gr_mat_window_clear(A22, ctx);

gr_mat_window_clear(B11, ctx);
gr_mat_window_clear(B12, ctx);
gr_mat_window_clear(B21, ctx);
gr_mat_window_clear(B22, ctx);

gr_mat_window_clear(C11, ctx);
gr_mat_window_clear(C12, ctx);
gr_mat_window_clear(C21, ctx);
gr_mat_window_clear(C22, ctx);

/* don't compute the overlapping entries twice */
if (bc > 2 * bnc)
{
gr_mat_window_init(Br, B, 0, 0, ac, 2 * bnc, ctx);
status |= gr_mat_mul(Cr, Ar, Br, ctx);
gr_mat_window_clear(Br, ctx);
gr_mat_t Bc, Cc;
gr_mat_window_init(Bc, B, 0, 2 * bnc, ac, bc, ctx);
gr_mat_window_init(Cc, C, 0, 2 * bnc, ar, bc, ctx);
status |= gr_mat_mul(Cc, A, Bc, ctx);
gr_mat_window_clear(Bc, ctx);
gr_mat_window_clear(Cc, ctx);
}
else

if (ar > 2 * anr)
{
status |= gr_mat_mul(Cr, Ar, B, ctx);
gr_mat_t Ar, Br, Cr;
gr_mat_window_init(Ar, A, 2 * anr, 0, ar, ac, ctx);
gr_mat_window_init(Cr, C, 2 * anr, 0, ar, 2 * bnc, ctx);

/* don't compute the overlapping entries twice */
if (bc > 2 * bnc)
{
gr_mat_window_init(Br, B, 0, 0, ac, 2 * bnc, ctx);
status |= gr_mat_mul(Cr, Ar, Br, ctx);
gr_mat_window_clear(Br, ctx);
}
else
{
status |= gr_mat_mul(Cr, Ar, B, ctx);
}

gr_mat_window_clear(Ar, ctx);
gr_mat_window_clear(Cr, ctx);
}

gr_mat_window_clear(Ar, ctx);
gr_mat_window_clear(Cr, ctx);
}
if (ac > 2 * anc)
{
gr_mat_t Ac, Br, Cb, tmp;
slong mt, nt;

if (ac > 2 * anc)
{
gr_mat_t Ac, Br, Cb, tmp;
slong mt, nt;

gr_mat_window_init(Ac, A, 0, 2 * anc, 2 * anr, ac, ctx);
gr_mat_window_init(Br, B, 2 * bnr, 0, ac, 2 * bnc, ctx);
gr_mat_window_init(Cb, C, 0, 0, 2 * anr, 2 * bnc, ctx);

mt = Ac->r;
nt = Br->c;

gr_mat_init(tmp, mt, nt, ctx);
status |= gr_mat_mul(tmp, Ac, Br, ctx);
status |= gr_mat_add(Cb, Cb, tmp, ctx);
gr_mat_clear(tmp, ctx);
gr_mat_window_clear(Ac, ctx);
gr_mat_window_clear(Br, ctx);
gr_mat_window_clear(Cb, ctx);
gr_mat_window_init(Ac, A, 0, 2 * anc, 2 * anr, ac, ctx);
gr_mat_window_init(Br, B, 2 * bnr, 0, ac, 2 * bnc, ctx);
gr_mat_window_init(Cb, C, 0, 0, 2 * anr, 2 * bnc, ctx);

mt = Ac->r;
nt = Br->c;

gr_mat_init(tmp, mt, nt, ctx);
status |= gr_mat_mul(tmp, Ac, Br, ctx);
status |= gr_mat_add(Cb, Cb, tmp, ctx);
gr_mat_clear(tmp, ctx);
gr_mat_window_clear(Ac, ctx);
gr_mat_window_clear(Br, ctx);
gr_mat_window_clear(Cb, ctx);
}
}

return status;
Expand Down
Loading

0 comments on commit ed0f2b0

Please sign in to comment.