diff --git a/CICD/analyze.sh b/CICD/analyze.sh
new file mode 100755
index 0000000..1d71f4f
--- /dev/null
+++ b/CICD/analyze.sh
@@ -0,0 +1,4 @@
+python -m pylint src
+python -m pyright src
+python -m black src --check
+python -m isort src --check-only
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 0175518..81a428b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,9 @@ dependencies = [
     "pillow",
     "diffusers",
     "torch",
-    "deepspeed"
+    "deepspeed",
+    "hydra-core",
+    "fastapi"
 ]
 [project.optional-dependencies]
 test = ['pytest']
diff --git a/src/AGISwarm/text2image_ms/__main__.py b/src/AGISwarm/text2image_ms/__main__.py
index ea49858..53c0885 100644
--- a/src/AGISwarm/text2image_ms/__main__.py
+++ b/src/AGISwarm/text2image_ms/__main__.py
@@ -13,12 +13,9 @@
 import hydra
 import uvicorn
-from fastapi import Body, FastAPI, WebSocket
+from fastapi import FastAPI, WebSocket
 from fastapi.responses import FileResponse, HTMLResponse
 from fastapi.staticfiles import StaticFiles
-from hydra.core.config_store import ConfigStore
-from jinja2 import Environment, FileSystemLoader
-from PIL import Image
 
 from .diffusion_pipeline import Text2ImagePipeline
 from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig
 
@@ -52,37 +49,41 @@
         """
         await websocket.accept()
-        
+
         try:
             while True:
                 await asyncio.sleep(0.1)
                 data = await websocket.receive_text()
-                print (data)
+                print(data)
                 gen_config = Text2ImageGenerationConfig.model_validate_json(data)
                 request_id = str(uuid.uuid4())
 
-                async for step_info in self.text2image_pipeline.generate(request_id, gen_config):
-                    if step_info['type'] == 'waiting':
+                async for step_info in self.text2image_pipeline.generate(
+                    request_id, gen_config
+                ):
+                    if step_info["type"] == "waiting":
                         await websocket.send_json(step_info)
                         continue
-                    latents = step_info['image']
-                    
+                    latents = step_info["image"]
+
                     # Convert the latents to base64
                     image_io = BytesIO()
-                    latents.save(image_io, 'PNG')
-                    dataurl = 'data:image/png;base64,' + base64.b64encode(image_io.getvalue()).decode('ascii')
+                    latents.save(image_io, "PNG")
+                    dataurl = "data:image/png;base64," + base64.b64encode(
+                        image_io.getvalue()
+                    ).decode("ascii")
                     # Send progress info and the latents to the client
-                    await websocket.send_json({
-                        "type": "generation_step",
-                        "step": step_info['step'],
-                        "total_steps": step_info['total_steps'],
-                        "latents": dataurl,
-                        "shape": latents.size
-                    })
-
-                await websocket.send_json({
-                    "type": "generation_complete"
-                })
-        except Exception as e:
+                    await websocket.send_json(
+                        {
+                            "type": "generation_step",
+                            "step": step_info["step"],
+                            "total_steps": step_info["total_steps"],
+                            "latents": dataurl,
+                            "shape": latents.size,
+                        }
+                    )
+
+                await websocket.send_json({"type": "generation_complete"})
+        except Exception as e:  # pylint: disable=broad-except
             logging.error(e)
             traceback.print_exc()
             await websocket.send_json(
diff --git a/src/AGISwarm/text2image_ms/diffusion_pipeline.py b/src/AGISwarm/text2image_ms/diffusion_pipeline.py
index 1d2d8fb..9a8e0fe 100644
--- a/src/AGISwarm/text2image_ms/diffusion_pipeline.py
+++ b/src/AGISwarm/text2image_ms/diffusion_pipeline.py
@@ -11,7 +11,6 @@
 import numpy as np
 import torch
 from diffusers import DiffusionPipeline, StableDiffusionPipeline
-from diffusers.callbacks import PipelineCallback
 from PIL import Image
 
 from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig
@@ -19,7 +18,13 @@
 
 
 class DiffusionIteratorStreamer:
-
+    """
+    A class to stream intermediate latents out of a diffusion pipeline.
+
+    Args:
+        timeout (Optional[Union[float, int]]): The timeout for the stream.
+    """
+
     def __init__(self, timeout: Optional[Union[float, int]] = None):
         self.latents_stack = []
         self.stop_signal: Optional[str] = None
@@ -47,6 +52,7 @@
             raise StopAsyncIteration()
         return latents
 
+    # pylint: disable=unused-argument
     def callback(
         self,
         pipeline: DiffusionPipeline,
@@ -59,6 +65,7 @@
         self.put(callback_kwargs["latents"])
         return {"latents": callback_kwargs["latents"]}
 
+    # pylint: disable=too-many-arguments
     def stream(
         self,
         pipe: StableDiffusionPipeline,
@@ -91,18 +98,20 @@ def run_pipeline():
         return self
 
 
+# pylint: disable=too-few-public-methods
 class Text2ImagePipeline:
     """
     A class to generate images from text prompts using the Stable Diffusion model.
 
     Args:
-        config (Text2ImagePipelineConfig): The configuration for the Diffusion Pipeline initialization.
-            - model (str): The model to use for generating the image.
-            - dtype (str): The data type to use for the model.
-            - device (str): The device to run the model on.
-            - safety_checker (str | None): The safety checker to use for the model.
-            - requires_safety_checker (bool): Whether the model requires a safety checker.
-            - low_cpu_mem_usage (bool): Whether to use low CPU memory usage.
+        config (Text2ImagePipelineConfig):
+            The configuration for the Diffusion Pipeline initialization.
+            - model (str): The model to use for generating the image.
+            - dtype (str): The data type to use for the model.
+            - device (str): The device to run the model on.
+            - safety_checker (str | None): The safety checker to use for the model.
+            - requires_safety_checker (bool): Whether the model requires a safety checker.
+            - low_cpu_mem_usage (bool): Whether to use low CPU memory usage.
     """
 
     def __init__(self, config: Text2ImagePipelineConfig):
@@ -120,23 +129,20 @@ def __init__(self, config: Text2ImagePipelineConfig):
         # self.pipeline.enable_sequential_cpu_offload()
 
     @partial(generation_request_queued_func, wait_time=0.1)
-    async def generate(
-        self,
-        request_id: str,
-        gen_config: Text2ImageGenerationConfig
-    ):
+    async def generate(self, request_id: str, gen_config: Text2ImageGenerationConfig):
         """
         Generate an image from a text prompt using the Text2Image pipeline.
 
         Args:
-            gen_config (Text2ImageGenerationConfig): The configuration for the Text2Image Pipeline generation.
-                - prompt (str): The text prompt to generate the image from.
-                - negative_prompt (str): The negative text prompt to generate the image from.
-                - num_inference_steps (int): The number of inference steps to run.
-                - guidance_scale (float): The guidance scale to use for the model.
-                - seed (int): The seed to use for the model.
-                - width (int): The width of the image to generate.
-                - height (int): The height of the image to generate.
+            gen_config (Text2ImageGenerationConfig):
+                The configuration for the Text2Image Pipeline generation.
+                - prompt (str): The text prompt to generate the image from.
+                - negative_prompt (str): The negative text prompt to generate the image from.
+                - num_inference_steps (int): The number of inference steps to run.
+                - guidance_scale (float): The guidance scale to use for the model.
+                - seed (int): The seed to use for the model.
+                - width (int): The width of the image to generate.
+                - height (int): The height of the image to generate.
 
         Yields:
             dict: A dictionary containing the step information for the generation.
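To make the generate API above concrete, here is a minimal hypothetical driver that consumes Text2ImagePipeline.generate the same way the WebSocket handler in __main__.py does. It is a sketch, not part of this patch: the model id and all config values are assumptions inferred from the docstrings above.

    # Hypothetical usage sketch (not part of this patch). Field names follow
    # the docstrings above; the model id and concrete values are assumptions.
    import asyncio
    import uuid

    from AGISwarm.text2image_ms.diffusion_pipeline import Text2ImagePipeline
    from AGISwarm.text2image_ms.typing import (
        Text2ImageGenerationConfig,
        Text2ImagePipelineConfig,
    )


    async def main() -> None:
        pipeline = Text2ImagePipeline(
            Text2ImagePipelineConfig(
                model="runwayml/stable-diffusion-v1-5",  # assumed model id
                dtype="float16",
                device="cuda",
                safety_checker=None,
                requires_safety_checker=False,
                low_cpu_mem_usage=True,
            )
        )
        gen_config = Text2ImageGenerationConfig(
            prompt="a watercolor fox",
            negative_prompt="",
            num_inference_steps=20,
            guidance_scale=7.5,
            seed=42,
            width=512,
            height=512,
        )
        # generate is an async generator: it first yields queue "waiting"
        # events, then one dict per denoising step with a PIL image under
        # the "image" key (see the WebSocket handler in __main__.py).
        async for step_info in pipeline.generate(str(uuid.uuid4()), gen_config):
            if step_info["type"] == "waiting":
                continue
            print(f"step {step_info['step']}/{step_info['total_steps']}")


    asyncio.run(main())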
diff --git a/src/AGISwarm/text2image_ms/typing.py b/src/AGISwarm/text2image_ms/typing.py
index 674c035..cc0c590 100644
--- a/src/AGISwarm/text2image_ms/typing.py
+++ b/src/AGISwarm/text2image_ms/typing.py
@@ -33,4 +33,4 @@ class Text2ImageGenerationConfig(BaseModel):
     guidance_scale: float
     seed: int
     width: int
-    height: int
\ No newline at end of file
+    height: int
diff --git a/src/AGISwarm/text2image_ms/utils.py b/src/AGISwarm/text2image_ms/utils.py
index 93874e6..b8e7f82 100644
--- a/src/AGISwarm/text2image_ms/utils.py
+++ b/src/AGISwarm/text2image_ms/utils.py
@@ -2,10 +2,6 @@
 
 import asyncio
 import threading
-from abc import abstractmethod
-from typing import Dict, Generic, List, Protocol, TypeVar, cast, runtime_checkable
-
-from pydantic import BaseModel
 
 __ABORT_EVENTS = {}
 __QUEUE = []
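For the full protocol picture: the handler in __main__.py sends "waiting", "generation_step", and "generation_complete" JSON messages, where "latents" carries a base64 PNG data URL. Below is a minimal hypothetical client exercising that protocol. The ws:// URL (host, port, and route) is an assumption, since the uvicorn wiring is outside this diff, and the third-party websockets package is used purely for illustration.

    # Hypothetical client (not part of this patch); the ws:// URL is assumed.
    import asyncio
    import base64
    import json

    import websockets  # third-party client library, not a project dependency


    async def main() -> None:
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            # The payload mirrors Text2ImageGenerationConfig, which the server
            # parses with model_validate_json.
            await ws.send(
                json.dumps(
                    {
                        "prompt": "a watercolor fox",
                        "negative_prompt": "",
                        "num_inference_steps": 20,
                        "guidance_scale": 7.5,
                        "seed": 42,
                        "width": 512,
                        "height": 512,
                    }
                )
            )
            while True:
                msg = json.loads(await ws.recv())
                if msg["type"] == "generation_step":
                    # "latents" is a data URL: drop the prefix to get PNG bytes.
                    png = base64.b64decode(msg["latents"].split(",", 1)[1])
                    print(f"step {msg['step']}/{msg['total_steps']}: {len(png)} bytes")
                elif msg["type"] == "generation_complete":
                    break


    asyncio.run(main())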