Merge pull request #2 from innat/feat_kerasv3
Move to KerasV3
innat authored Mar 23, 2024
2 parents 9ac28d6 + a6a07a0 commit 0953651
Showing 33 changed files with 4,155 additions and 2,261 deletions.
44 changes: 44 additions & 0 deletions .github/workflows/actions.yml
@@ -0,0 +1,44 @@
name: Tests

on:
push:
workflow_call:
release:
types: [created]

permissions:
contents: read

jobs:
video_swin:
name: Test Video Swin Transformer with Keras 3
strategy:
fail-fast: false
matrix:
backend: [tensorflow, jax, torch]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v5
with:
python-version: 3.9
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install -e ".[tests]" --progress-bar off --upgrade
- name: Test with pytest
run: |
pytest test/ --durations 0
5 changes: 5 additions & 0 deletions .gitignore
@@ -6,6 +6,11 @@ __pycache__/
# C extensions
*.so

# Ignore large file
*.h5
*.weights.h5
*.keras

# Distribution / packaging
.Python
build/
10 changes: 5 additions & 5 deletions CITATION.cff
@@ -5,7 +5,7 @@ cff-version: 1.2.0
title: videoswin-keras
message: >-
If you use this implementation, please cite it using the
metadata from this file
metadata from this file.
type: software
authors:
- given-names: Mohammed
@@ -14,9 +14,9 @@ authors:
identifiers:
- type: url
value: 'https://github.com/innat/VideoSwin'
description: Keras reimplementation of VideoSwin
description: Keras 3 implementation of VideoSwin
keywords:
- software
license: MIT
version: 1.0.0
date-released: '2023-10-19'
license: Apache License
version: 2.0.0
date-released: '2024-03-25'
2 changes: 1 addition & 1 deletion LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2023 Mohammed Innat
Copyright 2024 Mohammed Innat

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
40 changes: 23 additions & 17 deletions MODEL_ZOO.md
@@ -1,35 +1,42 @@

# VideoSwin Model Zoo
# Video Swin Transformer Model Zoo

Video Swin in `keras` can be used with multiple backends, i.e. `tensorflow`, `torch`, and `jax`. The input shape is expected to be `channel_last`, i.e. `(depth, height, width, channel)`.
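A minimal sketch of this convention (the backend choice, zero-valued input, and constructor arguments are for illustration only and follow the README example):

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"; must be set before Keras is imported

import numpy as np
from videoswin import VideoSwinT

# channel-last video batch: (batch, depth, height, width, channel)
video = np.zeros((1, 32, 224, 224, 3), dtype="float32")

model = VideoSwinT(num_classes=400, include_rescaling=False, activation=None)
logits = model(video)  # -> shape (1, 400)
```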

## Note

While evaluating a video model on the classification task, multiple clips are sampled from each video. This process also involves taking multiple spatial crops of each sampled clip.

- `#Frame = #input_frame x #clip x #crop`. The frame interval is `2` when evaluating on the benchmark datasets; a short sketch of this arithmetic follows the list.
- `#input_frame` means how many frames are input for model during the test phase.
- `#input_frame` means how many frames are fed to the model during the test phase. For Video Swin, it is `32`.
- `#crop` means spatial crops (e.g., 3 for left/right/center crops).
- `#clip` means temporal clips (e.g., 5 means repeated temporal sampling of five clips with different start indices).
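The sketch below illustrates this arithmetic only; the model call, clip sampling, and cropping are placeholders rather than the benchmark evaluation code:

```python
import numpy as np

num_input_frames, num_clips, num_crops = 32, 4, 3  # the "32x4x3" setting => 384 frames per video
frame_interval = 2

def dummy_model(view):
    # placeholder for a trained classifier returning 400 class logits
    return np.random.rand(400)

def spatial_crops(clip, n):
    # placeholder for left/center/right cropping: here we simply split along the width
    return np.array_split(clip, n, axis=2)

video = np.random.rand(300, 256, 320, 3)  # (frames, height, width, channel), channel-last
starts = np.linspace(0, len(video) - num_input_frames * frame_interval, num_clips, dtype=int)

all_logits = []
for start in starts:                             # temporal clips
    clip = video[start : start + num_input_frames * frame_interval : frame_interval]
    for crop in spatial_crops(clip, num_crops):  # spatial crops
        all_logits.append(dummy_model(crop))     # one forward pass per view

scores = np.mean(all_logits, axis=0)             # average over the 4 x 3 = 12 views
```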

### Kinetics 400

In the training phase, the video swin mdoels are initialized with the pretrained weights of image swin models. In that case, `IN` referes to **ImageNet**.
# Checkpoints

| Backbone | Pretrain | #Frame | Top-1 | Top-5 | Checkpoints |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | IN-1K | 32x4x3 | 78.8 | 93.6 | [SavedModel](https://github.com/innat/VideoSwin/releases/download/v1.1/TFVideoSwinT_K400_IN1K_P244_W877_32x224.zip)/[h5](https://github.com/innat/VideoSwin/releases/download/v1.0/TFVideoSwinT_K400_IN1K_P244_W877_32x224.h5) |
| Swin-S | IN-1K | 32x4x3 | 80.6 | 94.5 | [SavedModel](https://github.com/innat/VideoSwin/releases/download/v1.1/TFVideoSwinS_K400_IN1K_P244_W877_32x224.zip)/[h5](https://github.com/innat/VideoSwin/releases/download/v1.0/TFVideoSwinS_K400_IN1K_P244_W877_32x224.h5) |
| Swin-B | IN-1K | 32x4x3 | 80.6 | 94.6 | [SavedModel](https://github.com/innat/VideoSwin/releases/download/v1.1/TFVideoSwinB_K400_IN1K_P244_W877_32x224.zip)/[h5](https://github.com/innat/VideoSwin/releases/download/v1.0/TFVideoSwinB_K400_IN1K_P244_W877_32x224.h5) |
| Swin-B | IN-22K | 32x4x3 | 82.7 | 95.5 | [SavedModel](https://github.com/innat/VideoSwin/releases/download/v1.1/TFVideoSwinB_K400_IN22K_P244_W877_32x224.zip)/[h5](https://github.com/innat/VideoSwin/releases/download/v1.0/TFVideoSwinB_K400_IN22K_P244_W877_32x224.h5) |
In the training phase, the Video Swin models are initialized with the pretrained weights of the image Swin models; here, `IN` refers to **ImageNet**. In the following tables, the `keras` checkpoints contain the complete model, so the `keras.saving.load_model` API can be used. In contrast, the `h5` checkpoints contain only the weights.
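For example (a minimal sketch; the checkpoint file names are placeholders, and the `.weights.h5` path follows the build-then-load pattern shown in the README):

```python
import numpy as np
import keras
from videoswin import VideoSwinT

# `.keras` checkpoint: the complete serialized model (architecture + weights).
model = keras.saving.load_model("videoswin_tiny_kinetics400.keras")  # placeholder file name

# `.weights.h5` checkpoint: weights only, so construct and build the architecture first.
model = VideoSwinT(num_classes=400, include_rescaling=False, activation=None)
_ = model(np.zeros((1, 32, 224, 224, 3), dtype="float32"))            # build the model
model.load_weights("videoswin_tiny_kinetics400.weights.h5")           # placeholder file name
```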

### Kinetics 400

| Model | Pretrain | #Frame | Top-1 | Top-5 | Checkpoints | config |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | IN-1K | 32x4x3 | 78.8 | 93.6 | [keras]()/[h5]() | [swin-t](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_tiny_patch244_window877_kinetics400_1k.py) |
| Swin-S | IN-1K | 32x4x3 | 80.6 | 94.5 | [keras]()/[h5]() | [swin-s](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_small_patch244_window877_kinetics400_1k.py) |
| Swin-B | IN-1K | 32x4x3 | 80.6 | 94.6 | [keras]()/[h5]() | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics400_1k.py) |
| Swin-B | IN-22K | 32x4x3 | 82.7 | 95.5 | [keras]()/[h5]() | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics400_22k.py) |

### Kinetics 600

| Backbone | Pretrain | #Frame | Top-1 | Top-5 | Checkpoints |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-B | IN-22K | 32x4x3 | 84.0 | 96.5 | [SavedModel](https://github.com/innat/VideoSwin/releases/download/v1.1/TFVideoSwinB_K600_IN22K_P244_W877_32x224.zip)/[h5](https://github.com/innat/VideoSwin/releases/download/v1.0/TFVideoSwinB_K600_IN22K_P244_W877_32x224.h5) |
| Model | Pretrain | #Frame | Top-1 | Top-5 | Checkpoints | config |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-B | IN-22K | 32x4x3 | 84.0 | 96.5 | [keras]()/[h5]() | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics600_22k.py) |

### Something-Something V2

| Backbone | Pretrain | #Frame | Top-1 | Top-5 | Checkpoints |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-B | Kinetics 400 | 32x1x3 | 69.6 | 92.7 | [SavedModel](https://github.com/innat/VideoSwin/releases/download/v1.1/TFVideoSwinB_SSV2_K400_P244_W1677_32x224.zip)/[h5](https://github.com/innat/VideoSwin/releases/download/v1.0/TFVideoSwinB_SSV2_K400_P244_W1677_32x224.h5) |
| Model | Pretrain | #Frame | Top-1 | Top-5 | Checkpoints | config |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-B | Kinetics 400 | 32x1x3 | 69.6 | 92.7 | [keras]()/[h5]() | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window1677_sthv2.py) |


## Weight Comparison
@@ -57,4 +64,3 @@ np.testing.assert_allclose(x, y, 1e-4, 1e-4)
np.testing.assert_allclose(x, y, 1e-5, 1e-5)
# OK
```
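As a self-contained illustration of the tolerance check above, the two arrays below stand in for logits produced by the official PyTorch model and this Keras port on the same clip (synthetic values, not real model outputs):

```python
import numpy as np

# placeholder logits: x from the reference PyTorch model, y from the Keras port
x = np.random.rand(1, 400).astype("float32")
y = x + np.random.normal(scale=1e-6, size=x.shape).astype("float32")

# passes when the two implementations agree within the given tolerances
np.testing.assert_allclose(x, y, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(x, y, rtol=1e-5, atol=1e-5)
```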

99 changes: 23 additions & 76 deletions README.md
@@ -4,24 +4,14 @@

[![Palestine](https://img.shields.io/badge/Free-Palestine-white?labelColor=green)](https://twitter.com/search?q=%23FreePalestine&src=typed_query)

[![arXiv](https://img.shields.io/badge/arXiv-2106.13230-darkred)](https://arxiv.org/abs/2106.13230) [![keras-2.12.](https://img.shields.io/badge/keras-2.12-darkred)]([?](https://img.shields.io/badge/keras-2.12-darkred)) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Q7A700MEI10UomikqjQJANWyFZktJCT-?usp=sharing) [![HugginFace badge](https://img.shields.io/badge/🤗%20Hugging%20Face-Spaces-yellow.svg)](https://huggingface.co/spaces/innat/VideoSwin) [![HugginFace badge](https://img.shields.io/badge/🤗%20Hugging%20Face-Hub-yellow.svg)](https://huggingface.co/innat/videoswin)
[![arXiv](https://img.shields.io/badge/arXiv-2106.13230-darkred)](https://arxiv.org/abs/2106.13230) [![keras-3](https://img.shields.io/badge/keras-3-darkred)]([?](https://img.shields.io/badge/keras-2.12-darkred)) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Q7A700MEI10UomikqjQJANWyFZktJCT-?usp=sharing) [![HuggingFace badge](https://img.shields.io/badge/🤗%20Hugging%20Face-Spaces-yellow.svg)](https://huggingface.co/spaces/innat/VideoSwin) [![HuggingFace badge](https://img.shields.io/badge/🤗%20Hugging%20Face-Hub-yellow.svg)](https://huggingface.co/innat/videoswin)


VideoSwin is a pure-transformer video modeling architecture that attains top accuracy on the major video recognition benchmarks. The authors advocate an inductive bias of locality in video transformers, which leads to a better speed-accuracy trade-off than previous approaches that compute self-attention globally even with spatial-temporal factorization. The locality of the proposed video architecture is realized by adapting the [**Swin Transformer**](https://arxiv.org/abs/2103.14030) designed for the image domain, while continuing to leverage the power of pre-trained image models.

This is a unofficial `Keras` implementation of [Video Swin transformers](https://arxiv.org/abs/2106.13230). The official `PyTorch` implementation is [here](https://github.com/SwinTransformer/Video-Swin-Transformer) based on [mmaction2](https://github.com/open-mmlab/mmaction2).
This is an unofficial `Keras 3` implementation of [Video Swin transformers](https://arxiv.org/abs/2106.13230). The official `PyTorch` implementation is [here](https://github.com/SwinTransformer/Video-Swin-Transformer), based on [mmaction2](https://github.com/open-mmlab/mmaction2). The official PyTorch weights have been converted to a `Keras 3`-compatible format, and this implementation supports running the model on multiple backends, i.e. TensorFlow, PyTorch, and JAX.

## News

- **[05-01-2024]**: Video Swin is now available on [Kaggle Models](https://www.kaggle.com/models/ipythonx/videoswin/) in **TensorFlow (Keras 2)**, **Keras V3**, **TFLite**, and **ONNX** formats.
- [24-11-2023]: [WIP https://github.com/innat/VideoSwin/pull/2] Supporting Keras V3 for TensorFlow, JAX, and PyTorch backend.
- **[24-10-2023]**: [Kinetics-400](https://www.deepmind.com/open-source/kinetics) test data set can be found on kaggle, [link](https://www.kaggle.com/datasets/ipythonx/k4testset/data?select=videos_val).
- **[14-10-2023]**: VideoSwin integrated into [Huggingface Space](https://huggingface.co/spaces/innat/VideoSwin).
- **[12-10-2023]**: GPU(s), TPU-VM for fine-tune training are supported, [colab](https://github.com/innat/VideoSwin/blob/main/notebooks/videoswin_video_classification.ipynb).
- **[09-10-2023]**: TensorFlow [SavedModel](https://www.tensorflow.org/guide/saved_model) (format) checkpoints, [link](https://github.com/innat/VideoSwin/releases/tag/v1.1).
- **[08-10-2023]**: VideoSwin checkpoints [SSV2](https://developer.qualcomm.com/software/ai-datasets/something-something) and [Kinetics-600](https://www.deepmind.com/open-source/kinetics) becomes available, [link](https://github.com/innat/VideoSwin/releases/tag/v1.0).
- **[07-10-2023]**: VideoSwin checkpoints on [Kinetics-400](https://www.deepmind.com/open-source/kinetics) becomes available, [link](https://github.com/innat/VideoSwin/releases/tag/v1.0).
- **[06-10-2023]**: Code of VideoSwin in Keras becomes available.

# Install

@@ -31,25 +21,37 @@ cd VideoSwin
pip install -e .
```

# Usage
# Checkpoints

The **VideoSwin** checkpoints are available in both `.weights.h5` and `.keras` formats. The variants of this model are `tiny`, `small`, and `base`. Check the [model zoo](https://github.com/innat/VideoSwin/blob/main/MODEL_ZOO.md) page for details.

The **VideoSwin** checkpoints are available in both `SavedModel` and `H5` formats. The variants of this models are `tiny`, `small`, and `base`. Check this [release](https://github.com/innat/VideoSwin/releases/tag/v1.0) and [model zoo](https://github.com/innat/VideoSwin/blob/main/MODEL_ZOO.md) page to know details of it. Following are some hightlights.

**Inference**

```python
from videoswin import VideoSwinT

>>> model = VideoSwinT(num_classes=400)
>>> model.load_weights('TFVideoSwinT_K400_IN1K_P244_W877_32x224.h5')
>>> import os
>>> import torch
>>> os.environ["KERAS_BACKEND"] = "torch"
>>> from videoswin import VideoSwinT

>>> model = VideoSwinT(
num_classes=400,
include_rescaling=False,
activation=None
)
>>> _ = model(torch.ones((1, 32, 224, 224, 3)))
>>> model.load_weights('model.weights.h5')

>>> container = read_video('sample.mp4')
>>> frames = frame_sampling(container, num_frames=32)
>>> y = model(frames)
>>> y.shape
>>> y_pred = model(frames)
>>> y_pred.shape
torch.Size([1, 400])

>>> probabilities = tf.nn.softmax(y_pred_tf)
>>> probabilities = probabilities.numpy().squeeze(0)
>>> probabilities = torch.nn.functional.softmax(y_pred, dim=-1).detach().numpy()
>>> probabilities = probabilities.squeeze(0)
>>> confidences = {
label_map_inv[i]: float(probabilities[i]) \
for i in np.argsort(probabilities)[::-1]
@@ -86,61 +88,6 @@ model.fit(...)
model.predict(...)
```
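A minimal, self-contained fine-tuning sketch under assumed choices (toy random data, 10 classes, `AdamW` optimizer, an assumed `compile` step); it only illustrates the `fit`/`predict` flow referenced above:

```python
import numpy as np
import keras
from videoswin import VideoSwinT

# toy stand-in for a real video dataset: 8 clips of 32 frames, 10 classes (assumption)
x = np.random.rand(8, 32, 224, 224, 3).astype("float32")
y = np.random.randint(0, 10, size=(8,))

model = VideoSwinT(num_classes=10, include_rescaling=False, activation=None)
model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(x, y, batch_size=1, epochs=1)
preds = model.predict(x, batch_size=1)  # -> shape (8, 10)
```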

**Attention Maps**

By passing `return_attns=True` in the forward pass, we can get the attention scores from each basic block of the model as well. For example,

```python
from videoswin import VideoSwinT

>>> model = VideoSwinT(num_classes=400)
>>> model.load_weights('TFVideoSwinT_K400_IN1K_P244_W877_32x224.h5')
>>> container = read_video('sample.mp4')
>>> frames = frame_sampling(container, num_frames=32)
>>> y, attns_scores = model(frames, return_attns=True)

for k, v in attns_scores.items():
print(k, v.shape) # num_heads, depth, seq_len, seq_len
TFBasicLayer1_att (128, 3, 392, 392)
TFBasicLayer2_att (32, 6, 392, 392)
TFBasicLayer3_att (8, 12, 392, 392)
TFBasicLayer4_att (2, 24, 392, 392)
```


## Model Zoo

The 3D swin-video checkpoints are listed in [`MODEL_ZOO.md`](MODEL_ZOO.md). Following are some highlights.

### Kinetics 400

In the training phase, the video swin models are initialized with the pretrained weights of image swin models. In that case, `IN` refers to **ImageNet**.

| Backbone | Pretrain | Top-1 | Top-5 | #params | FLOPs | config |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | IN-1K | 78.8 | 93.6 | 28M | ? | [swin-t](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_tiny_patch244_window877_kinetics400_1k.py) |
| Swin-S | IN-1K | 80.6 | 94.5 | 50M | ? | [swin-s](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_small_patch244_window877_kinetics400_1k.py) |
| Swin-B | IN-1K | 80.6 | 94.6 | 88M | ? | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics400_1k.py) |
| Swin-B | IN-22K | 82.7 | 95.5 | 88M | ? | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics400_22k.py) |

### Kinetics 600

| Backbone | Pretrain | Top-1 | Top-5 | #params | FLOPs | config |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-B | IN-22K | 84.0 | 96.5 | 88M | ? | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics600_22k.py) |

### Something-Something V2

| Backbone | Pretrain | Top-1 | Top-5 | #params | FLOPs | config |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-B | Kinetics 400 | 69.6 | 92.7 | 89M | ? | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window1677_sthv2.py) |


# TODO
- [x] Custom fine-tuning code.
- [x] Publish on TF-Hub or Kaggle Model.
- [ ] Support `Keras V3` to support multi-framework backend.

## Citation

If you use this videoswin implementation in your research, please cite it using the metadata from our `CITATION.cff` file.