Skip to content

Commit

Permalink
[REFACTOR][Print statement]
Browse files (browse the repository at this point in the history)
  • Loading branch information
Kye committed Jan 6, 2024
1 parent 7b3e9cb commit b81d4a5
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
18 changes: 17 additions & 1 deletion mm_mamba/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
encoder_depth: int,
encoder_heads: int,
fusion_method: str = "mlp",
expansion_rate: int = 2,
*args,
**kwargs,
):
Expand All @@ -70,6 +71,9 @@ def __init__(
self.encoder_depth = encoder_depth
self.encoder_heads = encoder_heads
self.fusion_method = fusion_method

# Hidden dim
self.hidden_dim = dim * expansion_rate

# Set up the Mamba block
self.mamba = MambaBlock(
Expand All @@ -91,7 +95,7 @@ def __init__(
self.linear = nn.Linear(encoder_dim, dim)

# VisualExpert
self.fusion_layer = VisualExpert(dim, dim * 2, dropout, heads)
self.visual_expert = VisualExpert(dim, self.hidden_dim, dropout, heads)

# MLP
self.mlp = MLP(
Expand All @@ -111,13 +115,25 @@ def forward(self, text: Tensor, img: Tensor) -> Tensor:
"""
# Encode the image, Returns the same shape as text
encoded_img = self.encoder(img, return_embeddings=True)
# print(f"Image shape: {encoded_img.shape} inside the MultiModalMambaBlock")
# Project the image embeddings to the same dimension as the text embeddings
# We need to project the 2nd dim of the image embeddings to the same dimension as the text embeddings

# if the fusion method is mlp, use the mlp to fuse the text and image embeddings
if self.fusion_method == "mlp":
fusion_layer = self.mlp(encoded_img)
fused = fusion_layer + text

# If fusion method is concat, concatenate the text and image embeddings
if self.fusion_method == "concat":
fused = torch.concat([text, encoded_img], dim=1)

if self.fusion_method == "add":
fused = encoded_img + text

if self.fusion_method == "visual_expert":
concat = torch.cat([text, encoded_img], dim=1)
fused = self.visual_expert(concat)

return self.mamba(fused)

Expand Down
15 changes: 5 additions & 10 deletions mm_mamba/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
encoder_dim: int,
encoder_depth: int,
encoder_heads: int,
fusion_method: str = "mlp",
*args,
**kwargs,
):
Expand Down Expand Up @@ -95,6 +96,7 @@ def __init__(
encoder_dim,
encoder_depth,
encoder_heads,
fusion_method,
*args,
**kwargs,
)
Expand All @@ -112,7 +114,7 @@ def __init__(
self.lm_head.weight = self.embedding.weight

# Projection for the img
self.img_proj = nn.Linear(encoder_dim, dim)
self.img_proj = nn.Linear(dim, dim)

def forward(self, text: Tensor, img: Tensor) -> Tensor:
"""
Expand All @@ -125,18 +127,11 @@ def forward(self, text: Tensor, img: Tensor) -> Tensor:
Returns:
Tensor: Output logits.
"""
print(
f"Image shape: {img.shape} and text shape: {text.shape}"
)

x = self.embedding(text)
print(f"Text shacpe: {x.shape}")

# Project the image
# img = self.img_proj(img)

for layer in self.layers:
x = layer(x, img) + x
x = layer(x, img) # + x
# x = x + x

x = self.norm(x)
logits = self.lm_head(x)
Expand Down
1 change: 1 addition & 0 deletions model_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
encoder_dim=512,
encoder_depth=6,
encoder_heads=8,
fusion_method="visual_expert",
)

out = model(x, img)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "mmm-zeta"
version = "0.0.5"
version = "0.0.6"
description = "MMM - Pytorch"
license = "MIT"
authors = ["Kye Gomez <[email protected]>"]
Expand Down

0 comments on commit b81d4a5

Please sign in to comment.