[FEAT][MMM]
Kye committed Jan 4, 2024
1 parent abaffb7 commit 67d2b39
Showing 2 changed files with 129 additions and 7 deletions.
114 changes: 107 additions & 7 deletions mm_mamba/model.py
@@ -1,27 +1,86 @@
from torch import Tensor, nn
from zeta import RMSNorm

from mm_mamba import MultiModalMambaBlock
from mm_mamba.block import MultiModalMambaBlock


class MMM(nn.Module):
"""
MultiModalMamba model.
Args:
vocab_size (int): Size of the vocabulary.
dim (int): Dimension of the dense vectors.
depth (int): Number of layers in the model.
dropout (float): Dropout probability.
heads (int): Number of attention heads.
d_state (int): Dimension of the state.
image_size (int): Size of the input image.
patch_size (int): Size of the image patch.
encoder_dim (int): Dimension of the encoder.
encoder_depth (int): Number of layers in the encoder.
encoder_heads (int): Number of attention heads in the encoder.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Examples::
import torch
from mm_mamba.model import MMM
x = torch.randint(0, 10000, (1, 224))
img = torch.randn(1, 3, 224, 224)
model = MMM(
vocab_size=10000,
dim=512,
depth=6,
dropout=0.1,
heads=8,
d_state=512,
image_size=224,
patch_size=16,
encoder_dim=512,
encoder_depth=6,
encoder_heads=8,
)
out = model(x, img)
print(out.shape)
"""
def __init__(
self,
vocab_size: int,
dim: int,
depth: int,
dropout,
heads,
d_state,
dropout: float,
heads: int,
d_state: int,
image_size: int,
patch_size: int,
encoder_dim: int,
encoder_depth: int,
encoder_heads: int,
*args,
**kwargs,
):
super(MMM, self).__init__()
self.vocab_size = vocab_size
self.dim = dim
self.depth = depth

self.dropout = dropout
self.heads = heads
self.d_state = d_state
self.image_size = image_size
self.patch_size = patch_size
self.encoder_dim = encoder_dim
self.encoder_depth = encoder_depth
self.encoder_heads = encoder_heads

# Transforms integer indices to dense vectors of fixed size
self.embedding = nn.Embedding(vocab_size, dim)

# MultiModalMambaBlock in a list
self.layers = nn.ModuleList(
[
MultiModalMambaBlock(
@@ -30,13 +89,54 @@ def __init__(
dropout,
heads,
d_state,
image_size,
patch_size,
encoder_dim,
encoder_depth,
encoder_heads,
*args,
**kwargs,
)
]
)

# Normalization layer
self.rmsnorm = RMSNorm(dim)

self.norm = nn.LayerNorm(dim)

# Linear layer
self.lm_head = nn.Linear(dim, vocab_size, bias=False)

# Tie weights
self.lm_head.weight = self.embedding.weight

# Projection for the img
self.img_proj = nn.Linear(encoder_dim, dim)

def forward(self, text: Tensor, img: Tensor) -> Tensor:
self.embedding(text)
"""
Forward pass of the MultiModalMamba model.
Args:
text (Tensor): Input text tensor.
img (Tensor): Input image tensor.
Returns:
Tensor: Output logits.
"""
print(f"Image shape: {img.shape} and text shape: {text.shape}")

x = self.embedding(text)
print(f"Text shacpe: {x.shape}")

# Project the image
# img = self.img_proj(img)

# Each block fuses the image features into the text stream; the addition is a residual connection
for layer in self.layers:
x = layer(x, img) + x

x = self.norm(x)
logits = self.lm_head(x)

return logits

22 changes: 22 additions & 0 deletions model_example.py
@@ -0,0 +1,22 @@
import torch
from mm_mamba.model import MMM

x = torch.randint(0, 10000, (1, 224))
img = torch.randn(1, 3, 224, 224)

model = MMM(
vocab_size=10000,
dim=512,
depth=6,
dropout=0.1,
heads=8,
d_state=512,
image_size=224,
patch_size=16,
encoder_dim=512,
encoder_depth=6,
encoder_heads=8,
)

out = model(x, img)
print(out.shape)
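
Assuming the Mamba blocks preserve the sequence shape, `out` is a logits tensor of shape (batch, seq_len, vocab_size), here torch.Size([1, 224, 10000]). A minimal follow-up sketch, not part of this commit, that continues from the example above: it checks the tied weights and picks a greedy next token from the last position.

# The LM head reuses the embedding matrix (weight tying), so both point
# at the same underlying storage.
assert model.lm_head.weight.data_ptr() == model.embedding.weight.data_ptr()

# Greedy decoding sketch: take the logits at the last sequence position
# and pick the highest-scoring token id for each sequence in the batch.
next_token = out[:, -1, :].argmax(dim=-1)
print(next_token)  # tensor of shape (1,) with a token id in [0, 10000)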
