Skip to content

Commit

Permalink
dependencies, linter
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisDiachkov committed Sep 3, 2024
1 parent d09f194 commit 8e3bdc6
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 52 deletions.
4 changes: 4 additions & 0 deletions CICD/analyze.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
python -m pylint src
python -m pyright src
python -m black src --check
python -m isort src --check-only
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ dependencies = [
"pillow",
"diffusers",
"torch",
"deepspeed"
"deepspeed",
"hydra",
"fastapi",
"pillow"
]
[project.optional-dependencies]
test = ['pytest']
Expand Down
49 changes: 25 additions & 24 deletions src/AGISwarm/text2image_ms/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@

import hydra
import uvicorn
from fastapi import Body, FastAPI, WebSocket
from fastapi import FastAPI, WebSocket
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from hydra.core.config_store import ConfigStore
from jinja2 import Environment, FileSystemLoader
from PIL import Image

from .diffusion_pipeline import Text2ImagePipeline
from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig
Expand Down Expand Up @@ -52,37 +49,41 @@ async def generate(self, websocket: WebSocket):
"""

await websocket.accept()

try:
while True:
await asyncio.sleep(0.1)
data = await websocket.receive_text()
print (data)
print(data)
gen_config = Text2ImageGenerationConfig.model_validate_json(data)
request_id = str(uuid.uuid4())
async for step_info in self.text2image_pipeline.generate(request_id, gen_config):
if step_info['type'] == 'waiting':
async for step_info in self.text2image_pipeline.generate(
request_id, gen_config
):
if step_info["type"] == "waiting":
await websocket.send_json(step_info)
continue
latents = step_info['image']
latents = step_info["image"]

# Конвертируем латенты в base64
image_io = BytesIO()
latents.save(image_io, 'PNG')
dataurl = 'data:image/png;base64,' + base64.b64encode(image_io.getvalue()).decode('ascii')
latents.save(image_io, "PNG")
dataurl = "data:image/png;base64," + base64.b64encode(
image_io.getvalue()
).decode("ascii")
# Отправляем инфу о прогрессе и латенты
await websocket.send_json({
"type": "generation_step",
"step": step_info['step'],
"total_steps": step_info['total_steps'],
"latents": dataurl,
"shape": latents.size
})

await websocket.send_json({
"type": "generation_complete"
})
except Exception as e:
await websocket.send_json(
{
"type": "generation_step",
"step": step_info["step"],
"total_steps": step_info["total_steps"],
"latents": dataurl,
"shape": latents.size,
}
)

await websocket.send_json({"type": "generation_complete"})
except Exception as e: # pylint: disable=broad-except
logging.error(e)
traceback.print_exc()
await websocket.send_json(
Expand Down
50 changes: 28 additions & 22 deletions src/AGISwarm/text2image_ms/diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@
import numpy as np
import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from diffusers.callbacks import PipelineCallback
from PIL import Image

from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig
from .utils import generation_request_queued_func


class DiffusionIteratorStreamer:

"""
A class to stream the diffusion pipeline.
Args:
timeout (Optional[Union[float, int]]): The timeout for the stream.
"""

def __init__(self, timeout: Optional[Union[float, int]] = None):
self.latents_stack = []
self.stop_signal: Optional[str] = None
Expand Down Expand Up @@ -47,6 +52,7 @@ async def __anext__(self) -> torch.Tensor | str:
raise StopAsyncIteration()
return latents

# pylint: disable=unused-argument
def callback(
self,
pipeline: DiffusionPipeline,
Expand All @@ -59,6 +65,7 @@ def callback(
self.put(callback_kwargs["latents"])
return {"latents": callback_kwargs["latents"]}

# pylint: disable=too-many-arguments
def stream(
self,
pipe: StableDiffusionPipeline,
Expand Down Expand Up @@ -91,18 +98,20 @@ def run_pipeline():
return self


# pylint: disable=too-few-public-methods
class Text2ImagePipeline:
"""
A class to generate images from text prompts using the Stable Diffusion model.
Args:
config (Text2ImagePipelineConfig): The configuration for the Diffusion Pipeline initialization.
- model (str): The model to use for generating the image.
- dtype (str): The data type to use for the model.
- device (str): The device to run the model on.
- safety_checker (str | None): The safety checker to use for the model.
- requires_safety_checker (bool): Whether the model requires a safety checker.
- low_cpu_mem_usage (bool): Whether to use low CPU memory usage.
config (Text2ImagePipelineConfig):
The configuration for the Diffusion Pipeline initialization.
- model (str): The model to use for generating the image.
- dtype (str): The data type to use for the model.
- device (str): The device to run the model on.
- safety_checker (str | None): The safety checker to use for the model.
- requires_safety_checker (bool): Whether the model requires a safety checker.
- low_cpu_mem_usage (bool): Whether to use low CPU memory usage.
"""

def __init__(self, config: Text2ImagePipelineConfig):
Expand All @@ -120,23 +129,20 @@ def __init__(self, config: Text2ImagePipelineConfig):
# self.pipeline.enable_sequential_cpu_offload()

@partial(generation_request_queued_func, wait_time=0.1)
async def generate(
self,
request_id: str,
gen_config: Text2ImageGenerationConfig
):
async def generate(self, request_id: str, gen_config: Text2ImageGenerationConfig):
"""
Generate an image from a text prompt using the Text2Image pipeline.
Args:
gen_config (Text2ImageGenerationConfig): The configuration for the Text2Image Pipeline generation.
- prompt (str): The text prompt to generate the image from.
- negative_prompt (str): The negative text prompt to generate the image from.
- num_inference_steps (int): The number of inference steps to run.
- guidance_scale (float): The guidance scale to use for the model.
- seed (int): The seed to use for the model.
- width (int): The width of the image to generate.
- height (int): The height of the image to generate.
gen_config (Text2ImageGenerationConfig):
The configuration for the Text2Image Pipeline generation.
- prompt (str): The text prompt to generate the image from.
- negative_prompt (str): The negative text prompt to generate the image from.
- num_inference_steps (int): The number of inference steps to run.
- guidance_scale (float): The guidance scale to use for the model.
- seed (int): The seed to use for the model.
- width (int): The width of the image to generate.
- height (int): The height of the image to generate.
Yields:
dict: A dictionary containing the step information for the generation.
Expand Down
2 changes: 1 addition & 1 deletion src/AGISwarm/text2image_ms/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ class Text2ImageGenerationConfig(BaseModel):
guidance_scale: float
seed: int
width: int
height: int
height: int
4 changes: 0 additions & 4 deletions src/AGISwarm/text2image_ms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@

import asyncio
import threading
from abc import abstractmethod
from typing import Dict, Generic, List, Protocol, TypeVar, cast, runtime_checkable

from pydantic import BaseModel

__ABORT_EVENTS = {}
__QUEUE = []
Expand Down

0 comments on commit 8e3bdc6

Please sign in to comment.