
Could not perform index operation scatter #28

Open
zhangylch opened this issue Nov 17, 2024 · 4 comments

Comments

@zhangylch

zhangylch commented Nov 17, 2024

Hi

Thank you for implementing the forward Laplacian; it significantly accelerates the Laplacian calculation. However, I encounter an error when I try to use sparsity, and it seems that sparsity is not being applied. Here are the details of the issue:
[Screenshot of the error output]

Below is a minimal example that reproduces the error:

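```python
import time

import folx
import jax
import jax.numpy as jnp


def fwd(x):
    x = x.reshape(-1, 3)  # 300 coordinates -> 100 points in 3D
    distances = jnp.sqrt(jnp.sum(jnp.square(x), axis=1))
    # Indexed updates via .at[...].set(); sph does not affect the return value,
    # yet these ops are still traced when folx transforms fwd
    sph = jnp.zeros((2, distances.shape[0]))
    sph = sph.at[0].set(distances)
    sph = sph.at[1].set(distances * 5.0)
    return jnp.sum(x)


key = jax.random.PRNGKey(12)
x = jax.random.normal(key, (100, 300))  # batch of 100 inputs, 300 dims each

# 6 is the sparsity threshold passed to the forward Laplacian operator
lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(6)(fwd)))
jax.block_until_ready(lapl(x))  # warm-up call: triggers compilation
start_time = time.time()
jax.block_until_ready(lapl(x))
end_time = time.time()
print(end_time - start_time)
```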
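To see which operations in `fwd` produce the scatter that folx reports, one can trace the function with `jax.make_jaxpr` and look for the `scatter` primitive. The likely source here is the indexed updates `sph.at[0].set(...)`, which JAX lowers to scatter (the exact lowering may vary across JAX versions). `uses_scatter` and `build_sph` are hypothetical helpers written for this sketch, and the check is just a substring search over the printed jaxpr, which also includes nested sub-jaxprs.

```python
import jax
import jax.numpy as jnp


def uses_scatter(fn, *args):
    """Return True if the traced jaxpr of fn(*args) mentions a scatter-family primitive."""
    # str() of a ClosedJaxpr pretty-prints every equation, including nested sub-jaxprs
    return "scatter" in str(jax.make_jaxpr(fn)(*args))


def build_sph(x):
    # Same construction as in fwd, but returning sph so it is part of the output
    x = x.reshape(-1, 3)
    distances = jnp.sqrt(jnp.sum(jnp.square(x), axis=1))
    sph = jnp.zeros((2, distances.shape[0]))
    sph = sph.at[0].set(distances)
    sph = sph.at[1].set(distances * 5.0)
    return sph


x = jax.random.normal(jax.random.PRNGKey(12), (300,))
print(uses_scatter(build_sph, x))
print(jax.make_jaxpr(build_sph)(x))  # inspect the full jaxpr by eye
```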
@n-gao
Collaborator

n-gao commented Nov 22, 2024

Scatter ops are not yet implemented for sparse Jacobians. folx should still work correctly but falls back to a dense implementation. I'm open to accepting PRs that implement scatter ops.
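Until scatter support exists, one possible workaround (an assumption, not something suggested in this thread) is to build `sph` without indexed updates. The sketch below constructs the same 2 × N array with `jnp.stack`, which traces to reshape/concatenate-style primitives rather than scatter. Whether folx then keeps its sparse path depends on its support for those primitives and is not verified here. `sph_scatter` and `sph_stack` are hypothetical names used only for this comparison.

```python
import jax
import jax.numpy as jnp


def sph_scatter(x):
    # Original construction: two indexed updates into a zero array
    x = x.reshape(-1, 3)
    distances = jnp.sqrt(jnp.sum(jnp.square(x), axis=1))
    sph = jnp.zeros((2, distances.shape[0]))
    sph = sph.at[0].set(distances)
    sph = sph.at[1].set(distances * 5.0)
    return sph


def sph_stack(x):
    # Scatter-free variant: compute both rows, then stack them along axis 0
    x = x.reshape(-1, 3)
    distances = jnp.sqrt(jnp.sum(jnp.square(x), axis=1))
    return jnp.stack([distances, distances * 5.0], axis=0)


x = jax.random.normal(jax.random.PRNGKey(12), (300,))
print(jnp.allclose(sph_scatter(x), sph_stack(x)))  # same values either way
print("scatter" in str(jax.make_jaxpr(sph_stack)(x)))  # no scatter in the stacked version
```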

@zhangylch
Author

Thanks!

@ricor07

ricor07 commented Jan 2, 2025

Hello, may I work on this? Thank you.

@n-gao
Collaborator

n-gao commented Jan 2, 2025

For sure, I look forward to your implementation! If you need help don't hesitate to ask.
