Skip to content

Commit

Permalink
fix bugs and test code
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrik-johansson committed Jan 9, 2025
1 parent c4d92fb commit 86ca0b2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 29 deletions.
22 changes: 7 additions & 15 deletions src/gr_mat/sub_scalar.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,15 @@ gr_mat_scalar_other_sub(gr_mat_t res, gr_srcptr x, gr_ctx_t x_ctx, const gr_mat_
r = gr_mat_nrows(res, ctx);
c = gr_mat_ncols(res, ctx);

if (res == mat)
{
for (i = 0; i < FLINT_MIN(r, c); i++)
status |= gr_other_sub(GR_MAT_ENTRY(res, i, i, sz), x, x_ctx, GR_MAT_ENTRY(res, i, i, sz), ctx);
}
else
for (i = 0; i < r; i++)
{
for (i = 0; i < r; i++)
for (j = 0; j < c; j++)
{
for (j = 0; j < c; j++)
{
/* todo: vectorize */
if (i == j)
status |= gr_other_sub(GR_MAT_ENTRY(res, i, j, sz), x, x_ctx, GR_MAT_ENTRY(mat, i, j, sz), ctx);
else
status |= gr_set(GR_MAT_ENTRY(res, i, j, sz), GR_MAT_ENTRY(mat, i, j, sz), ctx);
}
/* todo: vectorize */
if (i == j)
status |= gr_other_sub(GR_MAT_ENTRY(res, i, j, sz), x, x_ctx, GR_MAT_ENTRY(mat, i, j, sz), ctx);
else
status |= gr_neg(GR_MAT_ENTRY(res, i, j, sz), GR_MAT_ENTRY(mat, i, j, sz), ctx);
}
}

Expand Down
45 changes: 31 additions & 14 deletions src/gr_mat/test/t-scalar.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ TEST_FUNCTION_START(gr_mat_scalar, state)
d = gr_heap_init(ctx);
c_other = gr_heap_init(ctx_other);

GR_MUST_SUCCEED(gr_mat_randtest(A, state, ctx));
GR_MUST_SUCCEED(gr_randtest(c, state, ctx));

have_other = (gr_set_other(c_other, c, ctx, ctx_other) == GR_SUCCESS);
Expand All @@ -70,62 +71,78 @@ TEST_FUNCTION_START(gr_mat_scalar, state)
{
status = GR_SUCCESS;

gr_mat_struct * A_or_B2_alias;
gr_mat_struct * A_or_B3_alias;

if (n_randint(state, 2))
{
A_or_B2_alias = A;
A_or_B3_alias = A;
}
else
{
status |= gr_mat_set(B2, A, ctx);
status |= gr_mat_set(B3, A, ctx);
A_or_B2_alias = B2;
A_or_B3_alias = B3;
}

if (testcase == 0)
{
/* A + c == A + C */
status |= gr_mat_add(B1, A, Cmn, ctx);
status |= gr_mat_add_scalar(B2, A, c, ctx);
status |= gr_mat_add_scalar(B2, A_or_B2_alias, c, ctx);
if (have_other)
status |= gr_mat_add_scalar_other(B3, A, c_other, ctx_other, ctx);
status |= gr_mat_add_scalar_other(B3, A_or_B3_alias, c_other, ctx_other, ctx);
}
else if (testcase == 1)
{
/* c + A == C + A */
status |= gr_mat_add(B1, Cmn, A, ctx);
status |= gr_mat_scalar_add(B2, c, A, ctx);
status |= gr_mat_scalar_add(B2, c, A_or_B2_alias, ctx);
if (have_other)
status |= gr_mat_scalar_other_add(B3, c_other, ctx_other, A, ctx);
status |= gr_mat_scalar_other_add(B3, c_other, ctx_other, A_or_B3_alias, ctx);
}
else if (testcase == 2)
{
/* A - c == A - C */
status |= gr_mat_sub(B1, A, Cmn, ctx);
status |= gr_mat_sub_scalar(B2, A, c, ctx);
status |= gr_mat_sub_scalar(B2, A_or_B2_alias, c, ctx);
if (have_other)
status |= gr_mat_sub_scalar_other(B3, A, c_other, ctx_other, ctx);
status |= gr_mat_sub_scalar_other(B3, A_or_B3_alias, c_other, ctx_other, ctx);
}
else if (testcase == 3)
{
/* c - A == C - A */
status |= gr_mat_sub(B1, Cmn, A, ctx);
status |= gr_mat_scalar_sub(B2, c, A, ctx);
status |= gr_mat_scalar_sub(B2, c, A_or_B2_alias, ctx);
if (have_other)
status |= gr_mat_scalar_other_sub(B3, c_other, ctx_other, A, ctx);
status |= gr_mat_scalar_other_sub(B3, c_other, ctx_other, A_or_B3_alias, ctx);
}
else if (testcase == 4)
{
/* A * c == A * C */
status |= gr_mat_mul(B1, A, Cnn, ctx);
status |= gr_mat_mul_scalar(B2, A, c, ctx);
status |= gr_mat_mul_scalar(B2, A_or_B2_alias, c, ctx);
if (have_other)
status |= gr_mat_mul_scalar_other(B3, A, c_other, ctx_other, ctx);
status |= gr_mat_mul_scalar_other(B3, A_or_B3_alias, c_other, ctx_other, ctx);
}
else if (testcase == 5)
{
/* A * c == A * C */
status |= gr_mat_mul(B1, Cmm, A, ctx);
status |= gr_mat_scalar_mul(B2, c, A, ctx);
status |= gr_mat_scalar_mul(B2, c, A_or_B2_alias, ctx);
if (have_other)
status |= gr_mat_scalar_other_mul(B3, c_other, ctx_other, A, ctx);
status |= gr_mat_scalar_other_mul(B3, c_other, ctx_other, A_or_B3_alias, ctx);
}
else if (testcase == 6)
{
/* A / c == A * c^(-1) */
status |= gr_inv(d, c, ctx);
status |= gr_mat_mul_scalar(B1, A, d, ctx);
status |= gr_mat_div_scalar(B2, A, c, ctx);
status |= gr_mat_div_scalar(B2, A_or_B2_alias, c, ctx);
if (have_other)
status |= gr_mat_div_scalar_other(B3, A, c_other, ctx_other, ctx);
status |= gr_mat_div_scalar_other(B3, A_or_B3_alias, c_other, ctx_other, ctx);
}

if (status == GR_SUCCESS && gr_mat_equal(B1, B2, ctx) == T_FALSE)
Expand Down

0 comments on commit 86ca0b2

Please sign in to comment.