Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to pull request #9 #1

Open
wants to merge 20 commits into
base: chain_iptm
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions alphafold/common/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@
# limitations under the License.

"""Functions for processing confidence metrics."""

from jax import debug
from functools import partial
from jax import jit
import jax.numpy as jnp
import jax
import numpy as np
from alphafold.common import residue_constants
import scipy.special
###




def compute_tol(prev_pos, current_pos, mask, use_jnp=False):
# Early stopping criteria based on criteria used in
Expand Down Expand Up @@ -197,6 +203,13 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None,
if chain_num is None:
chain_num = 1

# batch = {'asym_id': asym_id}
# apply_fn_jit = jax.jit(apply_fn)
# max_asym_id_traced = apply_fn_jit(batch)
# max_asym_id_as_int = int(max_asym_id_traced)

# jax.debug.print('max_asym_id_as_int={max_asym_id_as_int}',max_asym_id_as_int=max_asym_id_as_int)

# residue_weights has to be in [0, 1], but can be floating-point, i.e. the
# exp. resolved head's probability.
if residue_weights is None:
Expand Down Expand Up @@ -286,4 +299,4 @@ def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=F
else:
mean_score = confidence_metrics["mean_plddt"]
confidence_metrics["ranking_confidence"] = mean_score
return confidence_metrics
return confidence_metrics
12 changes: 8 additions & 4 deletions alphafold/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ class RunModel:
def __init__(self,
config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
is_training = False):
is_training = False,
chain_num=1):

self.config = config
self.params = params
self.multimer_mode = config.model.global_config.multimer_mode
self.chain_num = chain_num

if self.multimer_mode:
def _forward_fn(batch):
model = modules_multimer.AlphaFold(self.config.model)
model = modules_multimer.AlphaFold(self.config.model,self.chain_num)
return model(batch, is_training=is_training)
else:
def _forward_fn(batch):
Expand Down Expand Up @@ -134,6 +136,7 @@ def predict(self,
Returns:
A dictionary of model outputs.
"""

self.init_params(feat)
logging.info('Running predict with shape(feat) = %s',
tree.map_structure(lambda x: x.shape, feat))
Expand All @@ -146,6 +149,7 @@ def predict(self,
else:
num_ensemble = self.config.data.eval.num_ensemble
L = aatype.shape[1]


# initialize

Expand All @@ -169,6 +173,7 @@ def _jnp_to_np(x):

# initialize random key
key = jax.random.PRNGKey(random_seed)


# iterate through recyckes
for r in range(num_iters):
Expand Down Expand Up @@ -197,6 +202,5 @@ def _jnp_to_np(x):
break
if r > 0 and result["tol"] < self.config.model.recycle_early_stop_tolerance:
break

logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result))
return result, r
return result, r
8 changes: 3 additions & 5 deletions alphafold/model/modules_multimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,11 @@ class AlphaFold(hk.Module):
"""AlphaFold-Multimer model with recycling.
"""

def __init__(self, config, name='alphafold'):
def __init__(self, config,chain_num, name='alphafold'):
super().__init__(name=name)
self.config = config
self.global_config = config.global_config

self.chain_num = chain_num
def __call__(
self,
batch,
Expand All @@ -418,6 +418,7 @@ def __call__(
safe_key=None):

c = self.config
chain_num = self.chain_num
impl = AlphaFoldIteration(c, self.global_config)

if safe_key is None:
Expand Down Expand Up @@ -461,9 +462,6 @@ def apply_network(prev, safe_key):
if not return_representations:
del ret['representations']

# Extract chain NUM
chain_num = c.embeddings_and_evoformer.max_relative_chain + 1

# add confidence metrics
ret.update(confidence.get_confidence_metrics(
prediction_result=ret,
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@

setup(
name='alphafold-colabfold',
version='2.3.6',
version='2.3.8',
long_description_content_type='text/markdown',
description='An implementation of the inference pipeline of AlphaFold v2.3.1. '
'This is a completely new model that was entered as AlphaFold2 in CASP14 '
'and published in Nature. This package contains patches for colabfold.',
author='DeepMind',
author_email='[email protected]',
license='Apache License, Version 2.0',
url='https://github.com/sokrypton/alphafold',
url='https://github.com/ntnn19/alphafold/tree/chain_iptm',
packages=find_packages(),
install_requires=[
'absl-py',
Expand Down