From 3820c0a1fe6bb757d18986b1005e7ee72e04337c Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Wed, 9 Oct 2024 08:49:08 -0700 Subject: [PATCH] Add size option to choose size with pixels --- predict.py | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/predict.py b/predict.py index 4b68c0a..5da1bab 100644 --- a/predict.py +++ b/predict.py @@ -62,14 +62,24 @@ "9:21": (640, 1536), } +# 1 megapixel sizes +SIZES = {f"{x}x{y}": (x, y) for x, y in ASPECT_RATIOS.values()} +# 0.25 megapixel sizes +SIZES.update({f"{x / 2}x{y / 2}": (x / 2, y / 2) for x, y in ASPECT_RATIOS.values()}) + @dataclass class SharedInputs: prompt: Input = Input(description="Prompt for generated image") + size: Input = Input( + description="Size of the generated image", + choices=list(SIZES.keys()), + default="1024x1024", + ) aspect_ratio: Input = Input( description="Aspect ratio for the generated image", choices=list(ASPECT_RATIOS.keys()), - default="1:1", + default=None, ) num_outputs: Input = Input( description="Number of outputs to generate", default=1, le=4, ge=1 @@ -99,7 +109,7 @@ class SharedInputs: megapixels: Input = Input( description="Approximate number of megapixels for generated image", choices=["1", "0.25"], - default="1", + default=None, ) @@ -246,11 +256,25 @@ def predict(): raise Exception("You need to instantiate a predictor for a specific flux model") def preprocess( - self, aspect_ratio: str, seed: Optional[int], megapixels: str + self, + size: str, + aspect_ratio: str | None, + seed: Optional[int], + megapixels: str | None, ) -> Dict: - width, height = ASPECT_RATIOS.get(aspect_ratio) - if megapixels == "0.25": - width, height = width // 2, height // 2 + width, height = SIZES.get(size) + + # Backwards compatibility for deprecated aspect_ratio and megapixels inputs + if aspect_ratio is not None or megapixels is not None: + # set defaults + if aspect_ratio is None: + aspect_ratio = "1024x1024" + if megapixels is None: + megapixels = "1" + + width, height = ASPECT_RATIOS.get(aspect_ratio) + if megapixels == "0.25": + width, height = width // 2, height // 2 if not seed: seed = int.from_bytes(os.urandom(2), "big") @@ -468,6 +492,7 @@ def setup(self) -> None: def predict( self, prompt: str = SHARED_INPUTS.prompt, + size: str = SHARED_INPUTS.size, aspect_ratio: str = SHARED_INPUTS.aspect_ratio, num_outputs: int = SHARED_INPUTS.num_outputs, num_inference_steps: int = Input( @@ -483,7 +508,7 @@ def predict( go_fast: bool = SHARED_INPUTS.go_fast, megapixels: str = SHARED_INPUTS.megapixels, ) -> List[Path]: - hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels) + hws_kwargs = self.preprocess(size, aspect_ratio, seed, megapixels) if go_fast and not self.disable_fp8: imgs, np_imgs = self.fp8_predict( @@ -518,6 +543,7 @@ def setup(self) -> None: def predict( self, prompt: str = SHARED_INPUTS.prompt, + size: str = SHARED_INPUTS.size, aspect_ratio: str = SHARED_INPUTS.aspect_ratio, image: Path = Input( description="Input image for image to image mode. The aspect ratio of your output will match this image", @@ -549,7 +575,7 @@ def predict( if image and go_fast: print("img2img not supported with fp8 quantization; running with bf16") go_fast = False - hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels) + hws_kwargs = self.preprocess(size, aspect_ratio, seed, megapixels) if go_fast and not self.disable_fp8: imgs, np_imgs = self.fp8_predict(