
Type Error with af2_interface_metrics.py #53

Open
aravinda1879 opened this issue Mar 30, 2023 · 3 comments


aravinda1879 commented Mar 30, 2023

Hi,
First of all, I am not very familiar with JAX. My conda environment was built from the distributed YAML file, so it got JAX 0.3.24, as shown below:
jax 0.3.24 pypi_0 pypi
jaxlib 0.3.24 pypi_0 pypi

However, when running af2_interface_metrics.py on the silent files, I get the following error. Any thoughts on this? I get the same error with both AF 2.3.1 and AF 2.2.4. af2_metrics.py, on the other hand, works without any issue.

Traceback (most recent call last):
  File "/software/RFDesign2/scripts/af2_interface_metrics.py", line 597, in <module>
    predict_structure(tag_buffer, feature_dict_dict, binderlen_dict, initial_guess_dict, sfd_out, scorefilename)
  File "/software/RFDesign2/scripts/af2_interface_metrics.py", line 431, in predict_structure
    prediction_result = jax.vmap(model_runner.apply, in_axes=(None,None,0,0))(model_runner.params,
  File "/software/RFDesign2/envs/.conda/envs/rfdesign2/lib/python3.9/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/software/RFDesign2/envs/.conda/envs/rfdesign2/lib/python3.9/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
TypeError: _forward_fn() takes 1 positional argument but 2 were given

Appreciate it a lot.
Thanks!

@aravinda1879 (Author)

Okay, after reading a bit of model.py:

prediction_result = jax.vmap(model_runner.apply, in_axes=(None,None,0,0))(model_runner.params,
jax.random.PRNGKey(0), processed_feature_dict, processed_initial_guess_dict)

modules.AlphaFold is called inside _forward_fn, which is compiled with JAX's jit, and its batch argument should be a dictionary of inputs to the AlphaFold model.

However, _forward_fn() takes only 1 positional argument (batch), while line 431 of af2_interface_metrics.py currently passes 2 (processed_feature_dict and processed_initial_guess_dict)?
That said, I must be the one who is wrong, since the tutorial was run with this code, so I wonder what I am missing here. Again, I am pretty new to JAX and would appreciate an idiot's 101 explanation of this. :)
TIA!

@aravinda1879 (Author)

I hadn't noticed that #48 and #38 are basically the same issue as this one.

@jueseph (Collaborator)

jueseph commented Mar 30, 2023

See my reply in #48.


2 participants