What if we had a big model that could do everything for fine-tuned inference (#50)

* actually working inpainting

* multi-lora loading, fp8 & bf16

* ci/cd for lora models (#51)

* proper test config
daanelson authored Nov 27, 2024
1 parent 83b1af7 commit 39156c5
Showing 14 changed files with 483 additions and 129 deletions.
64 changes: 0 additions & 64 deletions .github/workflows/push-lora.yaml

This file was deleted.

46 changes: 23 additions & 23 deletions .github/workflows/push.yaml
@@ -2,17 +2,15 @@ name: Push Models

on:
  workflow_dispatch:
    branches: [main]
    inputs:
      no_push:
        description: 'Test only, without pushing to prod'
        type: boolean
        default: true
      models:
        description: 'Comma-separated list of models to push (schnell,dev,schnell-lora,dev-lora,hotswap-lora) or "all"'
        type: string
        default: 'all'

jobs:
  cog-safe-push:
    runs-on: ubuntu-latest-4-cores
    if: github.event.workflow == null # Only run for the original workflow

    steps:
      - uses: actions/checkout@v3
@@ -35,24 +33,26 @@ jobs:
        run: |
          pip install git+https://github.com/replicate/cog-safe-push.git
      - name: Select schnell
        run: |
          ./script/select.sh schnell
      - name: Run cog-safe-push on flux-schnell and optionally push to production
        env:
          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
          REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
        run: |
          cog-safe-push -vv ${{ github.event.inputs.no_push == 'true' && '--no-push' || '' }} --config=safe-push-configs/cog-safe-push-schnell.yaml
      - name: Select dev
        run: |
          ./script/select.sh dev
      - name: Run cog-safe-push on flux-dev and optionally push to production
      - name: Push selected models
        env:
          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
          REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
        run: |
          cog-safe-push -vv ${{ github.event.inputs.no_push == 'true' && '--no-push' || '' }} --config=safe-push-configs/cog-safe-push-dev.yaml
          if [ "${{ inputs.models }}" = "all" ]; then
            models="schnell,dev,schnell-lora,dev-lora,hotswap-lora"
          else
            models="${{ inputs.models }}"
          fi
          for model in ${models//,/ }; do
            echo "==="
            echo "==="
            echo "=== Pushing $model"
            echo "==="
            echo "==="
            ./script/select.sh $model
            cog-safe-push -vv
            if [ "$model" != "hotswap-lora" ]; then
              cog push r8.im/black-forest-labs/flux-$model # to get openapi schema :..(
            fi
          done
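
Note that workflow_dispatch inputs arrive as strings in the github.event context, which is why the step guards with no_push == 'true' rather than a boolean test. As a hedged sketch, a dispatch of this workflow could be scripted against GitHub's workflow-dispatch REST endpoint; the repository slug and token handling below are assumptions, not taken from this commit:

import os
import requests

# Trigger push.yaml via the REST API; inputs are always sent as strings,
# matching the == 'true' comparison in the workflow above.
# The "replicate/cog-flux" slug is an assumption for illustration.
resp = requests.post(
    "https://api.github.com/repos/replicate/cog-flux/actions/workflows/push.yaml/dispatches",
    headers={
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
        "Accept": "application/vnd.github+json",
    },
    json={"ref": "main", "inputs": {"no_push": "true", "models": "schnell,dev"}},
)
resp.raise_for_status()  # GitHub returns 204 No Content on success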
24 changes: 20 additions & 4 deletions flux/sampling.py
@@ -1,5 +1,5 @@
import math
from typing import Callable
from typing import Callable, Optional

import torch
from einops import rearrange, repeat
@@ -104,7 +104,10 @@ def denoise_single_item(
    vec: Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    compile_run: bool = False
    compile_run: bool = False,
    image_latents: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
    noise: Optional[Tensor] = None
):
    img = img.unsqueeze(0)
    img_ids = img_ids.unsqueeze(0)
@@ -133,6 +136,13 @@ def denoise_single_item(
        )

        img = img + (t_prev - t_curr) * pred.squeeze(0)
        if mask is not None:
            if t_prev != timesteps[-1]:
                proper_noise_latents = t_prev * noise + (1.0 - t_prev) * image_latents
            else:
                proper_noise_latents = image_latents

            img = (1 - mask) * proper_noise_latents + mask * img

    return img, model

@@ -147,7 +157,10 @@ def denoise(
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    compile_run: bool = False
    compile_run: bool = False,
    image_latents: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
    noise: Optional[Tensor] = None
):
    batch_size = img.shape[0]
    output_imgs = []
@@ -162,7 +175,10 @@
            vec[i],
            timesteps,
            guidance,
            compile_run
            compile_run,
            image_latents,
            mask,
            noise
        )
        compile_run = False
        output_imgs.append(denoised_img)
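
The new image_latents, mask, and noise arguments are what make inpainting work: after every denoising step, regions where the mask is 0 are overwritten with the source image re-noised to the current timestep, so the model only ever edits the masked region while unmasked content stays on the sampler's noise schedule. A runnable toy version of that blend, with made-up shapes and timestep:

import torch

t_prev = 0.6                             # current point on the noise schedule
noise = torch.randn(1, 256, 64)          # the sampler's initial noise
image_latents = torch.randn(1, 256, 64)  # encoded source image
img = torch.randn(1, 256, 64)            # model prediction after this step
mask = torch.zeros(1, 256, 64)
mask[:, 100:150, :] = 1.0                # 1 = inpaint here, 0 = keep source

# Re-noise the source to t_prev so kept regions match the schedule;
# at the final step the clean latents are used instead.
proper_noise_latents = t_prev * noise + (1.0 - t_prev) * image_latents
img = (1 - mask) * proper_noise_latents + mask * img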
6 changes: 3 additions & 3 deletions fp8/configs/config-1-flux-dev-h100.json
@@ -34,13 +34,13 @@
"scale_factor": 0.3611,
"shift_factor": 0.1159
},
"ckpt_path": "/src/model-cache/dev/dev.sft",
"ae_path": "/src/model-cache/ae/ae.sft",
"ckpt_path": "./model-cache/dev/dev.sft",
"ae_path": "./model-cache/ae/ae.sft",
"repo_id": "black-forest-labs/FLUX.1-dev",
"repo_flow": "flux1-dev.sft",
"repo_ae": "ae.sft",
"text_enc_max_length": 512,
"text_enc_path": "/src/model-cache/t5",
"text_enc_path": "./model-cache/t5",
"text_enc_device": "cuda:0",
"ae_device": "cuda:0",
"flux_device": "cuda:0",
6 changes: 3 additions & 3 deletions fp8/configs/config-1-flux-schnell-h100.json
@@ -34,13 +34,13 @@
"scale_factor": 0.3611,
"shift_factor": 0.1159
},
"ckpt_path": "/src/model-cache/schnell/schnell.sft",
"ae_path": "/src/model-cache/ae/ae.sft",
"ckpt_path": "./model-cache/schnell/schnell.sft",
"ae_path": "./model-cache/ae/ae.sft",
"repo_id": "black-forest-labs/FLUX.1-dev",
"repo_flow": "flux1-dev.sft",
"repo_ae": "ae.sft",
"text_enc_max_length": 256,
"text_enc_path": "/src/model-cache/t5",
"text_enc_path": "./model-cache/t5",
"text_enc_device": "cuda:0",
"ae_device": "cuda:0",
"flux_device": "cuda:0",
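
Both configs trade the absolute /src/model-cache/... paths for relative ./model-cache/... ones, so the weights now resolve against the process working directory instead of a fixed container layout. A small sketch of verifying that resolution; only the config filename and keys come from this diff, the helper is ours:

import json
from pathlib import Path

def check_config_paths(config_file: str) -> None:
    # Relative paths in the config resolve against os.getcwd(),
    # so this must be run from the repository root.
    cfg = json.loads(Path(config_file).read_text())
    for key in ("ckpt_path", "ae_path", "text_enc_path"):
        p = Path(cfg[key])
        status = "ok" if p.exists() else "MISSING"
        print(f"{key}: {p.resolve()} [{status}]")

check_config_paths("fp8/configs/config-1-flux-dev-h100.json")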
34 changes: 28 additions & 6 deletions fp8/lora_loading.py
@@ -489,11 +489,7 @@ def apply_lora_weight_to_module(
    fused_weight = (w_orig.float() + fused_lora).to(torch.bfloat16)
    return fused_weight


@torch.inference_mode()
def load_lora(model: Flux, lora_path: str | Path, lora_scale: float = 1.0):
    t = time.time()
    has_guidance = model.params.guidance_embed
def convert_lora_weights(lora_path: str | Path, has_guidance: bool):
    logger.info(f"Loading LoRA weights for {lora_path}")
    lora_weights = load_file(lora_path, device="cuda")
    is_kohya = any(".lora_down.weight" in k for k in lora_weights)
@@ -516,15 +512,41 @@ def load_lora(model: Flux, lora_path: str | Path, lora_scale: float = 1.0):
    else:
        lora_weights = convert_from_original_flux_checkpoint(lora_weights)
    logger.info("LoRA weights loaded")
    return lora_weights


@torch.inference_mode()
def load_loras(model: Flux, lora_paths: list[str] | list[Path], lora_scales: list[float]):
    t = time.time()
    for lora, scale in zip(lora_paths, lora_scales):
        load_lora(model, lora, scale)


@torch.inference_mode()
def load_lora(model: Flux, lora_path: str | Path, lora_scale: float = 1.0):
    t = time.time()
    has_guidance = model.params.guidance_embed

    lora_weights = convert_lora_weights(lora_path, has_guidance)

    f8_clones = apply_lora_to_model(model, lora_weights, lora_scale)

    model.f8_clones = apply_lora_to_model(model, lora_weights, lora_scale)
    logger.success(f"LoRA applied in {time.time() - t:.2}s")

    if hasattr(model, "lora_weights"):
        model.lora_weights.append((lora_weights, lora_scale))
    else:
        model.lora_weights = [(lora_weights, lora_scale)]

    if hasattr(model, "f8_clones") and f8_clones is not None and model.f8_clones is not None:
        # for subsequent lora loads, we only add clones for new modules
        for k in f8_clones.keys():
            if k not in model.f8_clones:
                model.f8_clones[k] = f8_clones[k]

    else:
        model.f8_clones = f8_clones

@torch.inference_mode()
def unload_loras(model: Flux):
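
The f8_clones bookkeeping is what keeps multi-LoRA loading reversible: the clone saved the first time a module is touched holds the pristine base weight, so subsequent loads must only add clones for modules they touch for the first time. A self-contained illustration of that merge rule, with placeholder keys and values standing in for real weight tensors:

def merge_f8_clones(existing: dict | None, new: dict | None) -> dict | None:
    # Mirror of the clone-merging logic in load_lora: never overwrite a
    # previously saved clone, since it holds the original base weight.
    if existing is None:
        return new
    if new is not None:
        for k, v in new.items():
            existing.setdefault(k, v)
    return existing

clones = merge_f8_clones(None, {"layer1": "base_w1"})  # first LoRA load
clones = merge_f8_clones(clones, {"layer1": "fused_w1", "layer2": "base_w2"})  # second load
assert clones == {"layer1": "base_w1", "layer2": "base_w2"}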
1 change: 1 addition & 0 deletions model-cog-configs/hotswap-lora.yaml
@@ -0,0 +1 @@
predict: "predict.py:HotswapPredictor"
