Skip to content

Commit

Permalink
config refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisDiachkov committed Sep 8, 2024
1 parent 0ac6b1c commit 664fd15
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 22 deletions.
Empty file removed config/__init__.py
Empty file.
5 changes: 4 additions & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
hf_model_name: "runwayml/stable-diffusion-v1-5"

defaults:
- t2i_model_config: default
- gui_config: default
- uvicorn_config: default

hydra:
job:
chdir: false
chdir: false
2 changes: 1 addition & 1 deletion config/gui_config/default.yaml
Original file line number Diff line number Diff line change
@@ -1 +1 @@
latent_update_frequency: 5
latent_update_frequency: !!int 5
10 changes: 4 additions & 6 deletions config/t2i_model_config/default.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
model: "runwayml/stable-diffusion-v1-5"
dtype: "float32"
device: "cuda:0"
safety_checker:
requires_safety_checker: True
low_cpu_mem_usage: True
dtype: !!str "float32"
device: !!str "cuda:0"
requires_safety_checker: !!bool True
low_cpu_mem_usage: !!bool True
4 changes: 4 additions & 0 deletions config/uvicorn_config/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
host: !!str "127.0.0.1"
port: !!int 8000
log_level: !!str debug
loop: !!str asyncio
12 changes: 7 additions & 5 deletions src/AGISwarm/text2image_ms/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ def main(config: Config):
"""
The main function for the Text2Image service.
"""
text2image_app = Text2ImageApp(config.t2i_model_config, config.gui_config)
text2image_app = Text2ImageApp(
config.hf_model_name, config.t2i_model_config, config.gui_config
)
uvicorn.run(
text2image_app.app,
host="127.0.0.1",
port=8002,
log_level="debug",
loop="asyncio",
host=config.uvicorn_config.host,
port=config.uvicorn_config.port,
log_level=config.uvicorn_config.log_level,
loop=config.uvicorn_config.loop,
)


Expand Down
6 changes: 4 additions & 2 deletions src/AGISwarm/text2image_ms/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ class Text2ImageApp:
A class to represent the Text2Image service.
"""

def __init__(self, config: DiffusionConfig, gui_config: GUIConfig):
def __init__(
self, hf_model_name: str, config: DiffusionConfig, gui_config: GUIConfig
):
self.app = FastAPI()
self.setup_routes()
self.queue_manager = AsyncIOQueueManager(sleep_time=0.0001)
self.text2image_pipeline = Text2ImagePipeline(config)
self.text2image_pipeline = Text2ImagePipeline(hf_model_name, config)
self.latent_update_frequency = gui_config.latent_update_frequency
self.start_abort_lock = asyncio.Lock()

Expand Down
8 changes: 4 additions & 4 deletions src/AGISwarm/text2image_ms/diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ class Text2ImagePipeline:
- low_cpu_mem_usage (bool): Whether to use low CPU memory usage.
"""

def __init__(self, config: DiffusionConfig):
def __init__(self, hf_model_name: str, config: DiffusionConfig):
self.config = config
self.pipeline = StableDiffusionPipeline.from_pretrained(
config.model,
hf_model_name,
torch_dtype=getattr(torch, config.dtype),
safety_checker=config.safety_checker,
requires_safety_checker=config.requires_safety_checker,
safety_checker=None,
requires_safety_checker=False,
low_cpu_mem_usage=config.low_cpu_mem_usage,
).to(config.device)

Expand Down
18 changes: 15 additions & 3 deletions src/AGISwarm/text2image_ms/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass

from pydantic import BaseModel
from uvicorn.config import LoopSetupType


@dataclass
Expand All @@ -14,14 +15,23 @@ class DiffusionConfig(BaseModel):
A class to hold the configuration for the Diffusion Pipeline initialization.
"""

model: str
dtype: str
device: str
safety_checker: str | None
requires_safety_checker: bool
low_cpu_mem_usage: bool


@dataclass
class UvicornConfig(BaseModel):
"""
A class to hold the configuration for the Uvicorn.
"""

host: str
port: int
log_level: str
loop: LoopSetupType


@dataclass
class GUIConfig(BaseModel):
"""
Expand All @@ -37,8 +47,10 @@ class Config(BaseModel):
A class to hold the configuration for the Text2Image Pipeline.
"""

hf_model_name: str
t2i_model_config: DiffusionConfig
gui_config: GUIConfig
uvicorn_config: UvicornConfig


class Text2ImageGenerationConfig(BaseModel):
Expand Down

0 comments on commit 664fd15

Please sign in to comment.