Registry of add_any and pad #27

Open
zgbkdlm opened this issue Nov 4, 2024 · 1 comment

Comments

zgbkdlm commented Nov 4, 2024

Hi,

I have encountered the following warnings about unregistered functions. These functions, add_any and pad, seem to be JAX primitives, and I can't find any instructions in README.md on how to register them. Do you have an example?

WARNING:[folx](/home/.venv/lib/python3.12/site-packages/folx/jvp.py:493:15 (get_jvp_function.<locals>.merged_fwd)) - add_any not in registry. The following call might be slow as we will compute the full hessian.
WARNING:[folx](/home/.venv/lib/python3.12/site-packages/folx/jvp.py:493:15 (get_jvp_function.<locals>.merged_fwd)) - add_any not in registry. The following call might be slow as we will compute the full hessian.
WARNING:[folx](/home/.venv/lib/python3.12/site-packages/folx/utils.py:98:12 (flat_wrap.<locals>.new_fn)) - pad not in registry. The following call might be slow as we will compute the full hessian.
WARNING:[folx](/home/.venv/lib/python3.12/site-packages/folx/utils.py:98:12 (flat_wrap.<locals>.new_fn)) - pad not in registry. The following call might be slow as we will compute the full hessian.
WARNING:[folx](/home/.venv/lib/python3.12/site-packages/folx/ad.py:70:39 (jacrev.<locals>.jacfun.<locals>.flat_f)) - pad not in registry. The following call might be slow as we will compute the full hessian.
WARNING:[folx](/home/.venv/lib/python3.12/site-packages/folx/ad.py:70:39 (jacrev.<locals>.jacfun.<locals>.flat_f)) - pad not in registry. The following call might be slow as we will compute the full hessian.
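
A sketch for locating where these primitives come from. It assumes (this is not confirmed in the thread, and the user's code is not shown) that they are produced by JAX's own reverse-mode differentiation rather than called directly: when an input is read more than once, JAX accumulates its cotangents with add_any, and the transpose of lax.slice is a pad. The helper primitive_names and the toy function f are hypothetical; jax.make_jaxpr is the standard JAX tracing API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def primitive_names(jaxpr):
    """Collect primitive names in a jaxpr, descending into nested sub-jaxprs."""
    names = set()
    for eqn in jaxpr.eqns:
        names.add(eqn.primitive.name)
        for param in eqn.params.values():
            # Params such as pjit's `jaxpr` or cond's `branches` hold sub-jaxprs.
            subs = param if isinstance(param, (tuple, list)) else (param,)
            for sub in subs:
                inner = getattr(sub, "jaxpr", sub)  # ClosedJaxpr -> Jaxpr
                if hasattr(inner, "eqns"):
                    names |= primitive_names(inner)
    return names


def f(x):
    # x is read twice, through two static slices (hypothetical toy function).
    return jnp.sum(lax.slice(x, (1,), (3,)) * lax.slice(x, (0,), (2,)))


x = jnp.arange(3.0)
closed = jax.make_jaxpr(jax.grad(f))(x)
names = primitive_names(closed.jaxpr)
print(sorted(names))
```

Running the same walker on a jaxpr of your own model can show which of your operations end up emitting the unregistered primitives.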

n-gao (Collaborator) commented Nov 22, 2024

Thanks a lot for your interest in folx! :)

To register add_any in your own code, you can add the following:

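import jax
from jax._src.ad_util import add_any_p  # private JAX module; its location may change between releases
from folx import register_function, wrap_forward_laplacian
from folx.api import FunctionFlags

# add_any is an elementwise sum, so it is linear and can be forwarded to jax.lax.add.
register_function(
    add_any_p,
    wrap_forward_laplacian(jax.lax.add, flags=FunctionFlags.LINEAR, in_axes=()),
)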
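
To see why FunctionFlags.LINEAR fits add_any, here is a dense toy check: for a sum, the value, Jacobian and Laplacian (trace of the Hessian) of the result are simply the sums of those of the summands. The helpers fwd_lapl and add_rule are hypothetical and store full dense Jacobians, unlike folx's FwdLaplArray; this sketch illustrates the math, not folx's implementation.

```python
import jax
import jax.numpy as jnp


def fwd_lapl(f, x):
    """Dense (value, Jacobian, Laplacian) of f at x; a toy stand-in for FwdLaplArray."""
    value = f(x)
    jac = jax.jacfwd(f)(x)                                   # shape (m, n)
    lap = jnp.trace(jax.hessian(f)(x), axis1=-2, axis2=-1)   # shape (m,)
    return value, jac, lap


def add_rule(a, b):
    # A linear op acts componentwise on value, Jacobian and Laplacian.
    return tuple(u + v for u, v in zip(a, b))


f = lambda x: jnp.sin(x) * x[0]
g = lambda x: x ** 3
x = jnp.array([0.3, -1.2, 2.0])

combined = add_rule(fwd_lapl(f, x), fwd_lapl(g, x))
direct = fwd_lapl(lambda x: f(x) + g(x), x)
```

The two triples agree up to floating-point rounding, which is exactly the property that lets a linear primitive skip the full Hessian.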

For a PR, please add it to the _LAPLACE_FN_REGISTRY dictionary in wrapped_functions.py.

For the pad primitive, it's a bit more complicated since the result depends on the padding mode, specifically on whether the padding value is a FwdLaplArray. If it is, the logic needs special handling. For non-FwdLaplArray padding values, one can simply rely on the FunctionFlags.INDEXING default.
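
A dense toy illustration of this distinction, assuming a 1-D output and lax.pad without interior padding. With a constant padding value, the padded entries have zero Jacobian and zero Laplacian, so one can pad the value with the constant and the derivatives with zeros. When the padding value itself depends on the input (the FwdLaplArray case), the padded entries carry that value's derivatives and the simple rule no longer matches. fwd_lapl and pad_const_rule are hypothetical helpers, not folx functions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def fwd_lapl(f, x):
    """Dense (value, Jacobian, Laplacian) of f at x; a toy stand-in for FwdLaplArray."""
    value = f(x)
    jac = jax.jacfwd(f)(x)                                   # shape (m, n)
    lap = jnp.trace(jax.hessian(f)(x), axis1=-2, axis2=-1)   # shape (m,)
    return value, jac, lap


def pad_const_rule(triple, pad_value, lo, hi):
    """Pad a 1-D (value, Jacobian, Laplacian) triple with a constant pad_value."""
    value, jac, lap = triple
    value = lax.pad(value, jnp.asarray(pad_value, value.dtype), [(lo, hi, 0)])
    # A constant has zero derivatives, so the padded Jacobian rows and Laplacian
    # entries are zero; the trailing input axis of the Jacobian is not padded.
    jac = lax.pad(jac, jnp.zeros((), jac.dtype), [(lo, hi, 0), (0, 0, 0)])
    lap = lax.pad(lap, jnp.zeros((), lap.dtype), [(lo, hi, 0)])
    return value, jac, lap


f = lambda x: jnp.sin(x) * x[0]
x = jnp.array([0.3, -1.2, 2.0])

# Constant padding value: the simple rule matches direct differentiation.
ours = pad_const_rule(fwd_lapl(f, x), 1.5, 2, 1)
ref = fwd_lapl(lambda x: lax.pad(f(x), jnp.float32(1.5), [(2, 1, 0)]), x)

# Padding value that depends on x: padded entries now have nonzero derivatives.
ref_dep = fwd_lapl(lambda x: lax.pad(f(x), x[1] ** 2, [(2, 1, 0)]), x)
```

In the dependent case, ref_dep[2][0] is the Laplacian of x[1]**2, which is 2, while the constant rule yields 0 at that position.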

Please feel free to open a PR to fix this issue.
