Merge pull request #53 from replicate/hotswap-fixes
Hotswap fixes
andreasjansson authored Nov 27, 2024
2 parents 876c2a2 + bc810d4 commit 2a6a4b2
Showing 1 changed file with 36 additions and 13 deletions.
predict.py: 36 additions & 13 deletions
@@ -94,10 +94,6 @@ class SharedInputs:
         description="Disable safety checker for generated images.",
         default=False,
     )
-    go_fast: Input = Input(
-        description="Run faster predictions with model optimized for speed (currently fp8 quantized); disable to run in original bf16",
-        default=True,
-    )
     lora_weights: Input = Input(
         description="Load LoRA weights. Supports Replicate models in the format <owner>/<username> or <owner>/<username>/<version>, HuggingFace URLs in the format huggingface.co/<owner>/<model-name>, CivitAI URLs in the format civitai.com/models/<id>[/<model-name>], or arbitrary .safetensors URLs from the Internet. For example, 'fofr/flux-pixar-cars'",
         default=None,
@@ -114,6 +110,17 @@ class SharedInputs:
         default="1",
     )
 
+    @property
+    def go_fast(self) -> Input:
+        return self.go_fast_with_default(True)
+
+    @staticmethod
+    def go_fast_with_default(default: bool) -> Input:
+        return Input(
+            description="Run faster predictions with model optimized for speed (currently fp8 quantized); disable to run in original bf16",
+            default=default,
+        )
+
 
 SHARED_INPUTS = SharedInputs()
 
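Aside: the hunk above turns the shared go_fast input into a small factory. The go_fast property keeps the original True default for existing call sites, while go_fast_with_default lets a predictor pick its own default (the last hunk in this diff passes False). A minimal sketch of the pattern, with a stand-in Input dataclass since cog's own class is not part of this diff:

from dataclasses import dataclass

@dataclass
class Input:
    # Stand-in for cog.Input, reduced to the two fields used here.
    description: str
    default: bool

class SharedInputs:
    @property
    def go_fast(self) -> Input:
        # Existing call sites keep reading SHARED_INPUTS.go_fast unchanged.
        return self.go_fast_with_default(True)

    @staticmethod
    def go_fast_with_default(default: bool) -> Input:
        return Input(
            description="Run faster fp8 predictions; disable for bf16",
            default=default,
        )

SHARED_INPUTS = SharedInputs()
assert SHARED_INPUTS.go_fast.default is True
assert SHARED_INPUTS.go_fast_with_default(False).default is False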
@@ -136,6 +143,9 @@ def base_setup(
         compile_fp8: bool = False,
         compile_bf16: bool = False,
         disable_fp8: bool = False,
+        t5=None,
+        clip=None,
+        ae=None,
     ) -> None:
         self.flow_model_name = flow_model_name
         print(f"Booting model {self.flow_model_name}")
@@ -173,15 +183,24 @@
 
         device = "cuda"
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512
-        self.t5 = load_t5(device, max_length=max_length)
-        self.clip = load_clip(device)
+        if t5:
+            self.t5 = t5
+        else:
+            self.t5 = load_t5(device, max_length=max_length)
+        if clip:
+            self.clip = clip
+        else:
+            self.clip = load_clip(device)
         self.flux = load_flow_model(
             self.flow_model_name, device="cpu" if self.offload else device
         )
         self.flux = self.flux.eval()
-        self.ae = load_ae(
-            self.flow_model_name, device="cpu" if self.offload else device
-        )
+        if ae:
+            self.ae = ae
+        else:
+            self.ae = load_ae(
+                self.flow_model_name, device="cpu" if self.offload else device
+            )
 
         self.num_steps = 4 if self.flow_model_name == "flux-schnell" else 28
         self.shift = self.flow_model_name != "flux-schnell"
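Aside: with the optional t5, clip, and ae arguments, base_setup only calls the loaders when no pre-built module is supplied, so two predictors in one process can share a single copy of each (the setup hunk further down passes the schnell predictor's modules into the dev predictor). A condensed sketch of the fallback, with a dummy loader standing in for load_t5, load_clip, and load_ae:

def dummy_loader():
    # Stand-in for load_t5/load_clip/load_ae; pretend this is expensive.
    return object()

class BasePredictor:
    def base_setup(self, t5=None, clip=None, ae=None):
        # Reuse caller-supplied modules; fall back to loading from scratch.
        self.t5 = t5 if t5 else dummy_loader()
        self.clip = clip if clip else dummy_loader()
        self.ae = ae if ae else dummy_loader()

first = BasePredictor()
first.base_setup()  # loads everything
second = BasePredictor()
second.base_setup(t5=first.t5, clip=first.clip, ae=first.ae)
assert second.t5 is first.t5  # shared, not duplicated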
@@ -783,8 +802,8 @@ def predict(
 
 
 class DevLoraPredictor(Predictor):
-    def setup(self) -> None:
-        self.base_setup("flux-dev", compile_fp8=True)
+    def setup(self, t5=None, clip=None, ae=None) -> None:
+        self.base_setup("flux-dev", compile_fp8=True, t5=t5, clip=clip, ae=ae)
         self.lora_setup()
 
     def predict(
@@ -855,7 +874,11 @@ def setup(self) -> None:
         self.schnell_lora.setup()
 
         self.dev_lora = DevLoraPredictor()
-        self.dev_lora.setup()
+        self.dev_lora.setup(
+            t5=self.schnell_lora.t5,
+            clip=self.schnell_lora.clip,
+            ae=self.schnell_lora.ae,
+        )
 
     def predict(
         self,
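Aside: because setup passes the schnell predictor's modules through unchanged, both predictors hold references to the same T5, CLIP, and autoencoder objects rather than loading second copies onto the GPU. Given a handle to the combined predictor (its class name sits outside this diff), the sharing can be verified directly:

# `predictor` is assumed to be the combined object whose setup() is shown above.
assert predictor.dev_lora.t5 is predictor.schnell_lora.t5
assert predictor.dev_lora.clip is predictor.schnell_lora.clip
assert predictor.dev_lora.ae is predictor.schnell_lora.ae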
@@ -915,7 +938,7 @@ def predict(
         output_format: str = SHARED_INPUTS.output_format,
         output_quality: int = SHARED_INPUTS.output_quality,
         disable_safety_checker: bool = SHARED_INPUTS.disable_safety_checker,
-        go_fast: bool = SHARED_INPUTS.go_fast,
+        go_fast: bool = SHARED_INPUTS.go_fast_with_default(False),
         megapixels: str = SHARED_INPUTS.megapixels,
         replicate_weights: str = SHARED_INPUTS.lora_weights,
         lora_scale: float = SHARED_INPUTS.lora_scale,