Skip to content

Commit

Permalink
Merge pull request #13 from neanes/onnx
Browse files (browse the repository at this point in the history)
Add left behind changes
  • Loading branch information
danielgarthur authored Jan 2, 2025
2 parents 461f9be + 505e772 commit 88a1ce7
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 44 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/publish-model.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Train and Release Model
name: Publish Model

on:
push:
Expand Down Expand Up @@ -42,7 +42,7 @@ jobs:
- name: Train
run: |
cd scripts
python train.py --epochs=1 --version $(echo "${{github.ref_name}}" | tr -d 'v\-model')
python train.py --version $(echo "${{github.ref_name}}" | tr -d 'v\-model')
mv current_model.pth ../models/
- name: Convert to ONNX
Expand All @@ -57,4 +57,5 @@ jobs:
models/metadata.json
models/current_model.pth
models/current_model.onnx
scripts/train_log.txt
draft: true
1 change: 1 addition & 0 deletions requirements-ci-model.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
onnx==1.17.0
torch==2.5.1
torchvision==0.20.1
7 changes: 3 additions & 4 deletions scripts/convert_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import argparse
import sys
import torch
import torchvision.models as models

from torch_model import load_model

sys.path.append("../src")
from model import load_metadata
from model_metadata import load_metadata


def convert_to_onnx(model, onnx_path):
Expand Down Expand Up @@ -58,7 +57,7 @@ def convert_to_onnx(model, onnx_path):
default="../models/metadata.json",
)
args = parser.parse_args()
classes = load_metadata(args.classes)
model = load_model(args.i, classes)
metadata = load_metadata(args.meta)
model = load_model(args.i, metadata.classes)

convert_to_onnx(model, args.o)
10 changes: 5 additions & 5 deletions scripts/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
from pathlib import Path
from PIL import Image

from torch_model import load_model
from torch_model import load_model, get_transform

sys.path.append("../src")
from model import load_metadata, get_transform
from model_metadata import load_metadata
from segmentation import segment
from text_removal import remove_text

Expand Down Expand Up @@ -202,8 +202,8 @@ def setup(img_transform=None, contour_filter=None):
# Run contour extraction
print("Extracting...")

classes = load_metadata(args.classes)
model = load_model(args.model, classes)
metadata = load_metadata(args.meta)
model = load_model(args.model, metadata.classes)
model.eval()

dataset = process_pdf(
Expand All @@ -212,7 +212,7 @@ def setup(img_transform=None, contour_filter=None):
args.pages,
args.o,
model,
classes,
metadata.classes,
img_transform,
contour_filter,
)
Expand Down
3 changes: 2 additions & 1 deletion scripts/dataset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from torch_model import get_transform

sys.path.append("../src")
from model import get_transform
from segmentation import segment
from text_removal import remove_text

Expand Down
10 changes: 5 additions & 5 deletions scripts/predict_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import torch.nn as nn
from PIL import Image

from torch_model import load_model
from torch_model import load_model, get_transform

sys.path.append("../src")
from model import load_metadata, get_transform
from model_metadata import load_metadata


def predict_image(model, classes, img_path):
Expand Down Expand Up @@ -64,9 +64,9 @@ def predict_image(model, classes, img_path):

args = parser.parse_args()

classes = load_metadata(args.classes)
model = load_model(args.model, classes)
metadata = load_metadata(args.meta)
model = load_model(args.model, metadata.classes)
model.eval()

prediction = predict_image(model, classes, args.infile)
prediction = predict_image(model, metadata.classes, args.infile)
print(json.dumps(prediction, indent=2))
4 changes: 2 additions & 2 deletions scripts/speed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from torch_model import load_model
from torch_model import load_model, get_transform

sys.path.append("../src")
from model import load_metadata, get_transform
from model_metadata import load_metadata
from segmentation import segment
from text_removal import remove_text

Expand Down
10 changes: 10 additions & 0 deletions scripts/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@ def load_model(model_path, classes):
model.load_state_dict(torch.load(model_path, weights_only=False))
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
return model


def get_transform():
    """Build the image preprocessing pipeline used before inference.

    The pipeline resizes the image to 224x224, converts it to a tensor and
    normalizes each of the three channels with a fixed mean and std.

    NOTE(review): `transforms` is expected to be `torchvision.transforms`,
    imported at the top of this module (the import is not visible in this
    hunk) -- confirm.

    Returns:
        The composed transform object produced by `transforms.Compose`.
    """
    return transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # These values equal the widely used ImageNet channel statistics;
            # presumably the backbone was pretrained with them -- TODO confirm.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
2 changes: 1 addition & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from test import test_model

sys.path.append("../src")
from model import ModelMetadata
from model_metadata import ModelMetadata


class EarlyStopper:
Expand Down
24 changes: 0 additions & 24 deletions src/model.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,11 @@
import json
import numpy as np
import onnxruntime as ort


class ModelMetadata:
    """Version and class-label information stored alongside a model.

    The two attributes are filled in by `from_json` and serialized by
    `to_dict`; both use the keys "model_version" and "classes".
    """

    def __init__(self):
        # Both fields stay None until from_json() (or the caller) sets them.
        self.model_version = None
        self.classes = None

    def __repr__(self):
        return (
            f"{type(self).__name__}("
            f"model_version={self.model_version!r}, classes={self.classes!r})"
        )

    def to_dict(self):
        """Return the metadata as a plain dict with the two metadata keys."""
        return {
            "model_version": self.model_version,
            "classes": self.classes,
        }

    def from_json(self, json):
        """Populate this instance in place from an already-parsed mapping.

        Args:
            json: Mapping containing "model_version" and "classes". The
                parameter name shadows the `json` module inside this method
                only; it is kept so existing keyword callers keep working.

        Raises:
            KeyError: If either key is missing from the mapping.
        """
        self.model_version = json["model_version"]
        self.classes = json["classes"]


def load_onnx_model(model_path):
    """Create an ONNX Runtime inference session for the model at `model_path`.

    Args:
        model_path: Location of the ONNX model, passed straight through to
            `ort.InferenceSession`.

    Returns:
        The `ort.InferenceSession` created for that model.
    """
    return ort.InferenceSession(model_path)


def load_metadata(metadata_path):
    """Read a metadata JSON file and return it as a `ModelMetadata`.

    Args:
        metadata_path: Path of the JSON file to read.

    Returns:
        A new `ModelMetadata` populated from the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a key required by `ModelMetadata.from_json` is missing.
    """
    metadata = ModelMetadata()
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding, which differs between platforms.
    with open(metadata_path, encoding="utf-8") as f:
        metadata.from_json(json.load(f))
    return metadata


def transform(img):
transformed = img.copy()
transformed = transformed / 255.0 # Normalize to [0, 1]
Expand Down
24 changes: 24 additions & 0 deletions src/model_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json


class ModelMetadata:
    """Version and class-label information stored alongside a model.

    The two attributes are filled in by `from_json` and serialized by
    `to_dict`; both use the keys "model_version" and "classes".
    """

    def __init__(self):
        # Both fields stay None until from_json() (or the caller) sets them.
        self.model_version = None
        self.classes = None

    def __repr__(self):
        return (
            f"{type(self).__name__}("
            f"model_version={self.model_version!r}, classes={self.classes!r})"
        )

    def to_dict(self):
        """Return the metadata as a plain dict with the two metadata keys."""
        return {
            "model_version": self.model_version,
            "classes": self.classes,
        }

    def from_json(self, json):
        """Populate this instance in place from an already-parsed mapping.

        Args:
            json: Mapping containing "model_version" and "classes". The
                parameter name shadows the `json` module inside this method
                only; it is kept so existing keyword callers keep working.

        Raises:
            KeyError: If either key is missing from the mapping.
        """
        self.model_version = json["model_version"]
        self.classes = json["classes"]


def load_metadata(metadata_path):
    """Read a metadata JSON file and return it as a `ModelMetadata`.

    Args:
        metadata_path: Path of the JSON file to read.

    Returns:
        A new `ModelMetadata` populated from the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a key required by `ModelMetadata.from_json` is missing.
    """
    metadata = ModelMetadata()
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding, which differs between platforms.
    with open(metadata_path, encoding="utf-8") as f:
        metadata.from_json(json.load(f))
    return metadata

0 comments on commit 88a1ce7

Please sign in to comment.