diff --git a/doc/source/gr_mat.rst b/doc/source/gr_mat.rst
index 93d8d2b9fa..e87a079938 100644
--- a/doc/source/gr_mat.rst
+++ b/doc/source/gr_mat.rst
@@ -278,6 +278,7 @@ Arithmetic
 
 .. function:: int gr_mat_mul_classical(gr_mat_t res, const gr_mat_t mat1, const gr_mat_t mat2, gr_ctx_t ctx)
               int gr_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+              int gr_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
               int gr_mat_mul_generic(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
               int gr_mat_mul(gr_mat_t res, const gr_mat_t mat1, const gr_mat_t mat2, gr_ctx_t ctx)
 
@@ -285,6 +286,9 @@ Arithmetic
     otherwise, it falls back to :func:`gr_mat_mul_generic` which currently
     only performs classical multiplication.
 
+    The *Waksman* algorithm assumes a commutative base ring which supports
+    exact division by two.
+
 .. function:: int gr_mat_sqr(gr_mat_t res, const gr_mat_t mat, gr_ctx_t ctx)
 
 .. function:: int gr_mat_add_scalar(gr_mat_t res, const gr_mat_t mat, gr_srcptr c, gr_ctx_t ctx)
diff --git a/src/gr_mat.h b/src/gr_mat.h
index 68b36147f1..b6ede2c3aa 100644
--- a/src/gr_mat.h
+++ b/src/gr_mat.h
@@ -145,6 +145,7 @@ WARN_UNUSED_RESULT int gr_mat_div_scalar(gr_mat_t res, const gr_mat_t mat, gr_sr
 
 WARN_UNUSED_RESULT int gr_mat_mul_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
 WARN_UNUSED_RESULT int gr_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
+WARN_UNUSED_RESULT int gr_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
 WARN_UNUSED_RESULT int gr_mat_mul_generic(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
 WARN_UNUSED_RESULT int gr_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
 
diff --git a/src/gr_mat/mul_waksman.c b/src/gr_mat/mul_waksman.c
new file mode 100644
index 0000000000..4ba770c65d
--- /dev/null
+++ b/src/gr_mat/mul_waksman.c
@@ -0,0 +1,164 @@
+/*
+    Copyright (C) 2024 Éric Schost
+    Copyright (C) 2024 Vincent Neiger
+    Copyright (C) 2024 Fredrik Johansson
+
+    This file is part of FLINT.
+
+    FLINT is free software: you can redistribute it and/or modify it under
+    the terms of the GNU Lesser General Public License (LGPL) as published
+    by the Free Software Foundation; either version 3 of the License, or
+    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
+*/
+
+#include "gr_mat.h"
+#include "gr_vec.h"
+
+/* todo: division by two should be divexact by two */
+/* todo: avoid redundant additions 0 + ... */
+
+/*
+    Sets C to the matrix product A * B, consuming the inner dimension two
+    indices at a time (plus one classical pass when it is odd).
+
+    The halving steps use gr_mul_2exp_si(..., -1); any failure flag they
+    or the other ring operations return is OR-ed into the result.  The
+    accumulators in the scratch vector are used without being zeroed here,
+    presumably relying on GR_TMP_INIT_VEC producing zero elements.
+
+    Returns GR_DOMAIN if the shapes of A, B and C are incompatible; this is
+    checked before the empty-product shortcut so that a zero dimension cannot
+    hide a shape mismatch.  If C aliases A or B, the product is computed into
+    a temporary matrix and swapped into C.
+*/
+int gr_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+{
+    int status = GR_SUCCESS;
+    slong m, n, p;
+    slong sz = ctx->sizeof_elem;
+
+    m = A->r;
+    n = A->c;
+    p = B->c;
+
+    /* Validate shapes first: the shortcut below must not accept bad input. */
+    if (n != B->r || m != C->r || p != C->c)
+    {
+        return GR_DOMAIN;
+    }
+
+    if (m == 0 || n == 0 || p == 0)
+    {
+        return gr_mat_zero(C, ctx);
+    }
+
+    if (A == C || B == C)
+    {
+        gr_mat_t T;
+        gr_mat_init(T, m, p, ctx);
+        status |= gr_mat_mul_waksman(T, A, B, ctx);
+        status |= gr_mat_swap_entrywise(T, C, ctx);
+        gr_mat_clear(T, ctx);
+        return status;
+    }
+
+    slong i, l, j, k;
+
+    gr_ptr tmp, Crow, Ccol, val0, val1, val2, crow;
+
+    /* Scratch layout: p row terms, m column terms, then four scalars. */
+    GR_TMP_INIT_VEC(tmp, p + m + 4, ctx);
+
+    Crow = tmp;
+    Ccol = GR_ENTRY(Crow, p, sz);
+    val0 = GR_ENTRY(Ccol, m, sz);
+    val1 = GR_ENTRY(val0, 1, sz);
+    val2 = GR_ENTRY(val1, 1, sz);
+    crow = GR_ENTRY(val2, 1, sz);
+
+    slong np = n >> 1;
+
+    for (i = 0; i < m; i++)
+        status |= _gr_vec_zero(GR_MAT_ENTRY(C, i, 0, sz), p, ctx);
+
+    for (j = 1; j <= np; j++)
+    {
+        /* This pass handles inner indices j2 - 1 and j2. */
+        slong j2 = (j << 1) - 1;
+
+        /* Row 0: sum-products go to C, difference-products to Crow. */
+        for (k = 0; k < p; k++)
+        {
+            status |= gr_add(val1, GR_MAT_ENTRY(A, 0, j2 - 1, sz), GR_MAT_ENTRY(B, j2, k, sz), ctx);
+            status |= gr_add(val2, GR_MAT_ENTRY(A, 0, j2, sz), GR_MAT_ENTRY(B, j2 - 1, k, sz), ctx);
+            status |= gr_addmul(GR_MAT_ENTRY(C, 0, k, sz), val1, val2, ctx);
+
+            status |= gr_sub(val1, GR_MAT_ENTRY(A, 0, j2 - 1, sz), GR_MAT_ENTRY(B, j2, k, sz), ctx);
+            status |= gr_sub(val2, GR_MAT_ENTRY(A, 0, j2, sz), GR_MAT_ENTRY(B, j2 - 1, k, sz), ctx);
+            status |= gr_addmul(GR_ENTRY(Crow, k, sz), val1, val2, ctx);
+        }
+
+        /* Column 0 (rows >= 1): same, with differences going to Ccol. */
+        for (l = 1; l < m; l++)
+        {
+            status |= gr_add(val1, GR_MAT_ENTRY(A, l, j2 - 1, sz), GR_MAT_ENTRY(B, j2, 0, sz), ctx);
+            status |= gr_add(val2, GR_MAT_ENTRY(A, l, j2, sz), GR_MAT_ENTRY(B, j2 - 1, 0, sz), ctx);
+            status |= gr_addmul(GR_MAT_ENTRY(C, l, 0, sz), val1, val2, ctx);
+
+            status |= gr_sub(val1, GR_MAT_ENTRY(A, l, j2 - 1, sz), GR_MAT_ENTRY(B, j2, 0, sz), ctx);
+            status |= gr_sub(val2, GR_MAT_ENTRY(A, l, j2, sz), GR_MAT_ENTRY(B, j2 - 1, 0, sz), ctx);
+            status |= gr_addmul(GR_ENTRY(Ccol, l, sz), val1, val2, ctx);
+        }
+
+        /* Remaining entries only need the sum-product. */
+        for (k = 1; k < p; k++)
+        {
+            for (l = 1; l < m; l++)
+            {
+                status |= gr_add(val1, GR_MAT_ENTRY(A, l, j2 - 1, sz), GR_MAT_ENTRY(B, j2, k, sz), ctx);
+                status |= gr_add(val2, GR_MAT_ENTRY(A, l, j2, sz), GR_MAT_ENTRY(B, j2 - 1, k, sz), ctx);
+                status |= gr_addmul(GR_MAT_ENTRY(C, l, k, sz), val1, val2, ctx);
+            }
+        }
+    }
+
+    /* Halve (sum + difference) per row; keep it in Ccol, fix up column 0. */
+    for (l = 1; l < m; l++)
+    {
+        status |= gr_add(val1, GR_ENTRY(Ccol, l, sz), GR_MAT_ENTRY(C, l, 0, sz), ctx);
+        status |= gr_mul_2exp_si(GR_ENTRY(Ccol, l, sz), val1, -1, ctx);
+        status |= gr_sub(GR_MAT_ENTRY(C, l, 0, sz), GR_MAT_ENTRY(C, l, 0, sz), GR_ENTRY(Ccol, l, sz), ctx);
+    }
+
+    /* Same correction for entry (0, 0); val0 is reused below. */
+    status |= gr_add(val1, Crow, GR_MAT_ENTRY(C, 0, 0, sz), ctx);
+    status |= gr_mul_2exp_si(val0, val1, -1, ctx);
+    status |= gr_sub(GR_MAT_ENTRY(C, 0, 0, sz), GR_MAT_ENTRY(C, 0, 0, sz), val0, ctx);
+
+    for (k = 1; k < p; k++)
+    {
+        status |= gr_add(crow, GR_ENTRY(Crow, k, sz), GR_MAT_ENTRY(C, 0, k, sz), ctx);
+        status |= gr_mul_2exp_si(val1, crow, -1, ctx);
+        status |= gr_sub(GR_MAT_ENTRY(C, 0, k, sz), GR_MAT_ENTRY(C, 0, k, sz), val1, ctx);
+        status |= gr_sub(crow, val1, val0, ctx);
+
+        /* Remove the column-k and row-l corrections from each inner entry. */
+        for (l = 1; l < m; l++)
+        {
+            status |= gr_sub(val2, GR_MAT_ENTRY(C, l, k, sz), crow, ctx);
+            status |= gr_sub(GR_MAT_ENTRY(C, l, k, sz), val2, GR_ENTRY(Ccol, l, sz), ctx);
+        }
+    }
+
+    /* Odd inner dimension: add the last rank-one term classically. */
+    if ((n & 1) == 1)
+        for (l = 0; l < m; l++)
+            for (k = 0; k < p; k++)
+                status |= gr_addmul(GR_MAT_ENTRY(C, l, k, sz),
+                    GR_MAT_ENTRY(A, l, n - 1, sz), GR_MAT_ENTRY(B, n - 1, k, sz), ctx);
+
+    GR_TMP_CLEAR_VEC(tmp, p + m + 4, ctx);
+
+    return status;
+}
+
diff --git a/src/gr_mat/test/main.c b/src/gr_mat/test/main.c
index b4e197be07..058517cd98 100644
--- a/src/gr_mat/test/main.c
+++ b/src/gr_mat/test/main.c
@@ -35,6 +35,7 @@
 #include "t-lu_recursive.c"
 #include "t-minpoly_field.c"
 #include "t-mul_strassen.c"
+#include "t-mul_waksman.c"
 #include "t-nullspace.c"
 #include "t-properties.c"
 #include "t-randrank.c"
@@ -82,6 +83,7 @@ test_struct tests[] =
     TEST_FUNCTION(gr_mat_lu_recursive),
     TEST_FUNCTION(gr_mat_minpoly_field),
     TEST_FUNCTION(gr_mat_mul_strassen),
+    TEST_FUNCTION(gr_mat_mul_waksman),
     TEST_FUNCTION(gr_mat_nullspace),
     TEST_FUNCTION(gr_mat_properties),
     TEST_FUNCTION(gr_mat_randrank),
diff --git a/src/gr_mat/test/t-mul_waksman.c b/src/gr_mat/test/t-mul_waksman.c
new file mode 100644
index 0000000000..3e12b77e38
--- /dev/null
+++ b/src/gr_mat/test/t-mul_waksman.c
@@ -0,0 +1,103 @@
+/*
+    Copyright (C) 2022 Fredrik Johansson
+
+    This file is part of FLINT.
+
+    FLINT is free software: you can redistribute it and/or modify it under
+    the terms of the GNU Lesser General Public License (LGPL) as published
+    by the Free Software Foundation; either version 3 of the License, or
+    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
+*/
+
+#include "test_helpers.h"
+#include "ulong_extras.h"
+#include "gr_mat.h"
+
+TEST_FUNCTION_START(gr_mat_mul_waksman, state)
+{
+    slong iter;
+
+    for (iter = 0; iter < 1000; iter++)
+    {
+        gr_ctx_t ctx;
+        gr_mat_t A, B, C, D;
+        slong a, b, c;
+        int status = GR_SUCCESS;
+        int can_div2;
+
+        /* Run over the integers or a random nmod ring; can_div2 is set for the
+           integers and for odd moduli only. */
+        if (n_randint(state, 2))
+        {
+            gr_ctx_init_fmpz(ctx);
+            can_div2 = 1;
+        }
+        else
+        {
+            ulong m = n_randtest_not_zero(state);
+            can_div2 = m % 2;
+            gr_ctx_init_nmod(ctx, m);
+        }
+
+        a = n_randint(state, 8);
+        b = n_randint(state, 2) ? a : n_randint(state, 8);
+        c = n_randint(state, 2) ? a : n_randint(state, 8);
+
+        gr_mat_init(A, a, b, ctx);
+        gr_mat_init(B, b, c, ctx);
+        gr_mat_init(C, a, c, ctx);
+        gr_mat_init(D, a, c, ctx);
+
+        status |= gr_mat_randtest(A, state, ctx);
+        status |= gr_mat_randtest(B, state, ctx);
+        status |= gr_mat_randtest(C, state, ctx);
+        status |= gr_mat_randtest(D, state, ctx);
+
+        /* Randomly exercise squaring and the aliased call forms when shapes allow. */
+        if (a == b && b == c && n_randint(state, 2))
+        {
+            status |= gr_mat_set(B, A, ctx);
+            status |= gr_mat_mul_waksman(C, A, A, ctx);
+        }
+        else if (b == c && n_randint(state, 2))
+        {
+            status |= gr_mat_set(C, A, ctx);
+            status |= gr_mat_mul_waksman(C, C, B, ctx);
+        }
+        else if (a == b && n_randint(state, 2))
+        {
+            status |= gr_mat_set(C, B, ctx);
+            status |= gr_mat_mul_waksman(C, A, C, ctx);
+        }
+        else
+        {
+            status |= gr_mat_mul_waksman(C, A, B, ctx);
+        }
+
+        status |= gr_mat_mul_classical(D, A, B, ctx);
+
+        /* With can_div2 the combined status must be success and C must match
+           classical multiplication; otherwise a mismatch only counts as a
+           failure when the status claims success. */
+        if ((can_div2 && (status != GR_SUCCESS || gr_mat_equal(C, D, ctx) != T_TRUE))
+            || (status == GR_SUCCESS && gr_mat_equal(C, D, ctx) != T_TRUE))
+        {
+            flint_printf("FAIL:\n");
+            gr_ctx_println(ctx);
+            flint_printf("A:\n"); gr_mat_print(A, ctx); flint_printf("\n\n");
+            flint_printf("B:\n"); gr_mat_print(B, ctx); flint_printf("\n\n");
+            flint_printf("C:\n"); gr_mat_print(C, ctx); flint_printf("\n\n");
+            flint_printf("D:\n"); gr_mat_print(D, ctx); flint_printf("\n\n");
+            flint_abort();
+        }
+
+        gr_mat_clear(A, ctx);
+        gr_mat_clear(B, ctx);
+        gr_mat_clear(C, ctx);
+        gr_mat_clear(D, ctx);
+
+        gr_ctx_clear(ctx);
+    }
+
+    TEST_FUNCTION_END(state);
+}