Skip to content

Commit

Permalink
refactoring, more configs
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisDiachkov committed Sep 5, 2024
1 parent f29c36d commit 41bc9fe
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 158 deletions.
9 changes: 3 additions & 6 deletions conf/config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
model: "runwayml/stable-diffusion-v1-5"
dtype: "float32"
device: "cuda:0"
safety_checker:
requires_safety_checker: True
low_cpu_mem_usage: False
defaults:
- diffusion_config: default
- gui_config: default

hydra:
job:
Expand Down
10 changes: 10 additions & 0 deletions conf/diffusion_config/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
model: "runwayml/stable-diffusion-v1-5"
dtype: "float32"
device: "cuda:0"
safety_checker:
requires_safety_checker: True
low_cpu_mem_usage: False

hydra:
job:
chdir: false
1 change: 1 addition & 0 deletions conf/gui_config/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
latent_update_frequency: 5
162 changes: 129 additions & 33 deletions src/AGISwarm/text2image_ms/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,63 @@
import asyncio
import base64
import logging
import multiprocessing as mp
import traceback
from functools import partial
from io import BytesIO
from pathlib import Path

import hydra
import nest_asyncio
import numpy as np
import torch
import uvicorn
from AGISwarm.asyncio_queue_manager import AsyncIOQueueManager, RequestStatus
from fastapi import APIRouter, FastAPI, WebSocket
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic.main import BaseModel

from .diffusion_pipeline import Text2ImagePipeline
from .typing import Text2ImageGenerationConfig, Text2ImagePipelineConfig
from .typing import Config, DiffusionConfig, GUIConfig, Text2ImageGenerationConfig


def _to_task(future: asyncio.Future, as_task: bool, loop: asyncio.AbstractEventLoop):
if not as_task or isinstance(future, asyncio.Task):
return future
return loop.create_task(future)


def asyncio_run(future, as_task=True):
"""
A better implementation of `asyncio.run`.
:param future: A future or task or call of an async method.
:param as_task: Forces the future to be scheduled as task (needed for e.g. aiohttp).
"""

try:
loop = asyncio.get_running_loop()
except RuntimeError: # no event loop running:
loop = asyncio.new_event_loop()
return loop.run_until_complete(_to_task(future, as_task, loop))
else:
nest_asyncio.apply(loop)
return asyncio.run(_to_task(future, as_task, loop))


class Text2ImageApp:
"""
A class to represent the Text2Image service.
"""

def __init__(self, config: Text2ImagePipelineConfig):
def __init__(self, config: DiffusionConfig, gui_config: GUIConfig):
self.app = FastAPI()
self.setup_routes()
self.queue_manager = AsyncIOQueueManager()
self.text2image_pipeline = Text2ImagePipeline(config)
self.latent_update_frequency = gui_config.latent_update_frequency

def setup_routes(self):
"""
Expand All @@ -49,6 +80,49 @@ def setup_routes(self):
self.ws_router.post("/abort")(self.abort)
self.app.include_router(self.ws_router)

@staticmethod
def diffusion_pipeline_step_callback(
websocket: WebSocket,
request_id: str,
abort_event: asyncio.Event,
total_steps: int,
latent_update_frequency: int,
pipeline,
step: int,
timestep: int,
callback_kwargs: dict,
):
"""Callback для StableDiffusionPipeline"""
if abort_event.is_set():
raise asyncio.CancelledError("Diffusion pipeline aborted")

if step == 0 or step != total_steps and step % latent_update_frequency != 0:
return {"latents": callback_kwargs["latents"]}

with torch.no_grad():
image = pipeline.decode_latents(callback_kwargs["latents"].clone())[0]
image = Image.fromarray((image * 255).astype(np.uint8))

image_io = BytesIO()
image.save(image_io, "PNG")
dataurl = "data:image/png;base64," + base64.b64encode(
image_io.getvalue()
).decode("ascii")

asyncio_run(
websocket.send_json(
{
"request_id": request_id,
"status": RequestStatus.RUNNING,
"step": step,
"total_steps": total_steps,
"latents": dataurl,
"shape": image.size,
}
)
)
return {"latents": callback_kwargs["latents"]}

async def generate(self, websocket: WebSocket):
"""
Generate an image from a text prompt using the Text2Image pipeline.
Expand All @@ -60,39 +134,55 @@ async def generate(self, websocket: WebSocket):
while True:
await asyncio.sleep(0.01)
data = await websocket.receive_text()
print(data)
# Read generation config
gen_config = Text2ImageGenerationConfig.model_validate_json(data)
generator = self.queue_manager.queued_generator(
# Enqueue the task (without starting it)
queued_task = self.queue_manager.queued_task(
self.text2image_pipeline.generate
)
request_id = generator.request_id
interrupt_event = self.queue_manager.abort_map[request_id]

async for step_info in generator(
gen_config, interrupt_event=interrupt_event
):
await asyncio.sleep(0.01)
print(step_info)
if step_info["status"] == RequestStatus.WAITING:
await websocket.send_json(step_info)
continue
if step_info["status"] != RequestStatus.RUNNING:
await websocket.send_json(step_info)
break
latents = step_info["image"]
image_io = BytesIO()
latents.save(image_io, "PNG")
dataurl = "data:image/png;base64," + base64.b64encode(
image_io.getvalue()
).decode("ascii")

# request_id and interrupt_event are created by the queued_generator
request_id = queued_task.request_id
abort_event = self.queue_manager.abort_map[request_id]

# Diffusion step callback
callback_on_step_end = partial(
self.diffusion_pipeline_step_callback,
websocket,
request_id,
abort_event,
gen_config.num_inference_steps,
self.latent_update_frequency
)

# Start the generation task
try:
async for step_info in queued_task(
gen_config, callback_on_step_end
):
if "status" not in step_info: # Task's return value.
await websocket.send_json(
{
"status": RequestStatus.FINISHED,
}
)
break
if (
step_info["status"] == RequestStatus.WAITING
): # Queuing info returned
await websocket.send_json(step_info)
continue
if (
step_info["status"] != RequestStatus.RUNNING
): # Queuing info returned
await websocket.send_json(step_info)
break
except asyncio.CancelledError as e:
logging.info(e)
await websocket.send_json(
{
"status": RequestStatus.ABORTED,
"request_id": request_id,
"status": RequestStatus.RUNNING,
"step": step_info["step"],
"total_steps": step_info["total_steps"],
"latents": dataurl,
"shape": latents.size,
}
)
except Exception as e: # pylint: disable=broad-except
Expand All @@ -117,7 +207,7 @@ async def abort(self, request: AbortRequest):
print(f"Aborting request {request.request_id}")
await self.queue_manager.abort_task(request.request_id)

def gui(self):
async def gui(self):
"""
Get the GUI for the Text2Image service.
"""
Expand All @@ -127,12 +217,18 @@ def gui(self):


@hydra.main(config_name="config")
def main(config: Text2ImagePipelineConfig):
def main(config: Config):
"""
The main function for the Text2Image service.
"""
text2image_app = Text2ImageApp(config)
uvicorn.run(text2image_app.app, host="127.0.0.1", port=8002, log_level="debug")
text2image_app = Text2ImageApp(config.diffusion_config, config.gui_config)
uvicorn.run(
text2image_app.app,
host="127.0.0.1",
port=8002,
log_level="debug",
loop="asyncio",
)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 41bc9fe

Please sign in to comment.