diff --git a/.github/workflows/ci-cd.yaml b/.github/workflows/ci-cd.yaml new file mode 100644 index 0000000..d6a2bb7 --- /dev/null +++ b/.github/workflows/ci-cd.yaml @@ -0,0 +1,128 @@ +name: replicate-ci-cd + +on: + pull_request: + branches: + - main + push: + branches: + - main + workflow_dispatch: + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + pip install ruff + + - name: Run ruff linter + run: | + ruff check + + - name: Run ruff formatter + run: | + ruff format --diff + + build-and-push: + runs-on: ubuntu-latest + if: ${{ !contains(github.event.head_commit.message, '[skip-cd]') }} + strategy: + matrix: + include: + - event: pull_request + model: 'dev', 'schnell' + env: 'test' + - event: push + model: 'dev', 'schnell' + env: 'prod' + - event: workflow_dispatch + model: 'dev', 'schnell' + env: 'prod' + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Determine changes + id: what-changed + run: | + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + FILES_CHANGED=$(git diff --name-only --diff-filter=AMR origin/${{ github.base_ref }} HEAD) + else + FILES_CHANGED=$(git diff --name-only --diff-filter=AMR HEAD~1 HEAD) + fi + echo "FILES_CHANGED=$FILES_CHANGED" >> $GITHUB_ENV + if echo "$FILES_CHANGED" | grep -q 'cog.yaml' || ${{ contains(github.event.head_commit.message, '[cog build]') }}; then + echo "cog-push=true" >> $GITHUB_OUTPUT + else + echo "cog-push=false" >> $GITHUB_OUTPUT + fi + + - name: Setup Cog + if: steps.what-changed.outputs.cog-push == 'true' + uses: replicate/setup-cog@v2 + with: + token: ${{ secrets.REPLICATE_API_TOKEN }} + install-cuda: false + cog-version: "v0.9.20" + + - name: Free disk space + if: steps.what-changed.outputs.cog-push == 'true' + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: false + swap-storage: true + + - name: Cog build and push + if: steps.what-changed.outputs.cog-push == 'true' + run: | + ./script/push.sh ${{ matrix.model }} ${{ matrix.env }} + + - name: Install Yolo + if: steps.what-changed.outputs.cog-push == 'false' + run: | + sudo curl -o /usr/local/bin/yolo -L "https://github.com/replicate/yolo/releases/latest/download/yolo_$(uname -s)_$(uname -m)" + sudo chmod +x /usr/local/bin/yolo + + - name: Yolo push + if: steps.what-changed.outputs.cog-push == 'false' + env: + REPLICATE_API_TOKEN: ${{secrets.REPLICATE_API_TOKEN}} + run: | + echo "pushing changes to ${{ matrix.model }}" + echo "changed files: $FILES_CHANGED" + yolo push --base ${{ matrix.model }} --dest ${{ matrix.model }} $FILES_CHANGED + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install Python test dependencies + run: | + pip install -r requirements_test.txt + + - name: Test model + env: + REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} + MODEL: ${{ matrix.model }} + TEST_ENV: ${{ matrix.env }} + run: | + ./script/test.sh ${{ matrix.model }} ${{ matrix.env }} diff --git a/flux/modules/conditioner.py b/flux/modules/conditioner.py index 4001fdb..ec61eee 100644 --- a/flux/modules/conditioner.py +++ b/flux/modules/conditioner.py @@ -10,10 +10,14 @@ def __init__(self, version: str, max_length: int, is_clip=False, **hf_kwargs): self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" if 
self.is_clip: - self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version + "/tokenizer", max_length=max_length) + self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( + version + "/tokenizer", max_length=max_length + ) self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version + "/model", **hf_kwargs) else: - self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version + "/tokenizer", max_length=max_length) + self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained( + version + "/tokenizer", max_length=max_length + ) self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version + "/model", **hf_kwargs) self.hf_module = self.hf_module.eval().requires_grad_(False) diff --git a/flux/sampling.py b/flux/sampling.py index 7a87a25..eb7971a 100644 --- a/flux/sampling.py +++ b/flux/sampling.py @@ -104,7 +104,7 @@ def denoise_single_item( vec: Tensor, timesteps: list[float], guidance: float = 4.0, - compile_run: bool = False + compile_run: bool = False, ): img = img.unsqueeze(0) img_ids = img_ids.unsqueeze(0) @@ -113,14 +113,14 @@ def denoise_single_item( vec = vec.unsqueeze(0) guidance_vec = torch.full((1,), guidance, device=img.device, dtype=img.dtype) - if compile_run: - torch._dynamo.mark_dynamic(img, 1, min=256, max=8100) # needs at least torch 2.4 + if compile_run: + torch._dynamo.mark_dynamic(img, 1, min=256, max=8100) # needs at least torch 2.4 torch._dynamo.mark_dynamic(img_ids, 1, min=256, max=8100) model = torch.compile(model) for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:])): t_vec = torch.full((1,), t_curr, dtype=img.dtype, device=img.device) - + pred = model( img=img, img_ids=img_ids, @@ -135,6 +135,7 @@ def denoise_single_item( return img, model + def denoise( model: Flux, # model input @@ -146,28 +147,21 @@ def denoise( # sampling parameters timesteps: list[float], guidance: float = 4.0, - compile_run: bool = False + compile_run: bool = False, ): batch_size = img.shape[0] output_imgs = [] for i in range(batch_size): denoised_img, model = denoise_single_item( - model, - img[i], - img_ids[i], - txt[i], - txt_ids[i], - vec[i], - timesteps, - guidance, - compile_run + model, img[i], img_ids[i], txt[i], txt_ids[i], vec[i], timesteps, guidance, compile_run ) compile_run = False output_imgs.append(denoised_img) - + return torch.cat(output_imgs), model + def unpack(x: Tensor, height: int, width: int) -> Tensor: return rearrange( x, diff --git a/flux/util.py b/flux/util.py index 469ae00..e624ba2 100644 --- a/flux/util.py +++ b/flux/util.py @@ -23,6 +23,7 @@ class ModelSpec: ae_path: str | None ae_url: str | None + T5_URL = "https://weights.replicate.delivery/default/official-models/flux/t5/t5-v1_1-xxl.tar" T5_CACHE = "./model-cache/t5" CLIP_URL = "https://weights.replicate.delivery/default/official-models/flux/clip/clip-vit-large-patch14.tar" @@ -191,4 +192,4 @@ def download_weights(url: str, dest: str): subprocess.check_call(["pget", "-x", url, dest], close_fds=False) else: subprocess.check_call(["pget", url, dest], close_fds=False) - print("downloading took: ", time.time() - start) \ No newline at end of file + print("downloading took: ", time.time() - start) diff --git a/integration-tests/test-model.py b/integration-tests/test-model.py new file mode 100644 index 0000000..5e04895 --- /dev/null +++ b/integration-tests/test-model.py @@ -0,0 +1,221 @@ +""" +It's tests. 
+Spins up a cog server and hits it with the HTTP API if you're local; runs through the python client if you're not +""" + +import base64 +import os +import subprocess +import sys +import time +import pytest +import requests +import replicate +from functools import partial +from PIL import Image +from io import BytesIO +import numpy as np + +ENVIRONMENT = os.getenv("ENVIRONMENT", "local") +LOCAL_ENDPOINT = "http://localhost:5000/predictions" +MODEL_NAME = os.getenv("MODEL_NAME", "no model configured") +IS_DEV = "dev" in MODEL_NAME + + +def local_run(model_endpoint: str, model_input: dict): + # TODO: figure this out for multi-image local predictions + st = time.time() + response = requests.post(model_endpoint, json={"input": model_input}) + et = time.time() - st + data = response.json() + + try: + datauri = data["output"] + base64_encoded_data = datauri.split(",")[1] + data = base64.b64decode(base64_encoded_data) + return et, Image.open(BytesIO(data)) + except Exception as e: + print("Error!") + print("input:", model_input) + print(data["logs"]) + raise e + + +def replicate_run(version: str, model_input: dict): + pred = replicate.predictions.create(version=version, input=model_input) + + pred.wait() + + predict_time = pred.metrics["predict_time"] + images = [] + for url in pred.output: + response = requests.get(url) + images.append(Image.open(BytesIO(response.content))) + print(pred.id) + return predict_time, images + + +def wait_for_server_to_be_ready(url, timeout=400): + """ + Waits for the server to be ready. + + Args: + - url: The health check URL to poll. + - timeout: Maximum time (in seconds) to wait for the server to be ready. + """ + start_time = time.time() + while True: + try: + response = requests.get(url) + data = response.json() + + if data["status"] == "READY": + return + elif data["status"] == "SETUP_FAILED": + raise RuntimeError("Server initialization failed with status: SETUP_FAILED") + + except requests.RequestException: + pass + + if time.time() - start_time > timeout: + raise TimeoutError("Server did not become ready in the expected time.") + + time.sleep(5) # Poll every 5 seconds + + +@pytest.fixture(scope="session") +def inference_func(): + if ENVIRONMENT == "local": + return partial(local_run, LOCAL_ENDPOINT) + elif ENVIRONMENT in {"test", "prod"}: + model = replicate.models.get(MODEL_NAME) + version = model.versions.list()[0] + return partial(replicate_run, version) + else: + raise Exception(f"env should be local, test, or prod but was {ENVIRONMENT}") + + +@pytest.fixture(scope="session", autouse=True) +def service(): + if ENVIRONMENT == "local": + print("building model") + # starts local server if we're running things locally + build_command = "cog build -t test-model".split() + subprocess.run(build_command, check=True) + container_name = "cog-test" + try: + subprocess.check_output(["docker", "inspect", '--format="{{.State.Running}}"', container_name]) + print(f"Container '{container_name}' is running. 
Stopping and removing...") + subprocess.check_call(["docker", "stop", container_name]) + subprocess.check_call(["docker", "rm", container_name]) + print(f"Container '{container_name}' stopped and removed.") + except subprocess.CalledProcessError: + # Container not found + print(f"Container '{container_name}' not found or not running.") + + run_command = f"docker run -d -p 5000:5000 --gpus all --name {container_name} test-model ".split() + process = subprocess.Popen(run_command, stdout=sys.stdout, stderr=sys.stderr) + + wait_for_server_to_be_ready("http://localhost:5000/health-check") + + yield + process.terminate() + process.wait() + stop_command = "docker stop cog-test".split() + subprocess.run(stop_command) + else: + yield + + +def get_time_bound(): + """entirely here to make sure we don't recompile""" + return 20 if IS_DEV else 10 + + +def test_base_generation(inference_func): + """standard generation for dev and schnell. assert that the output image has a dog in it with blip-2 or llava""" + test_example = { + "prompt": "A cool dog", + "aspect_ratio": "1:1", + "num_outputs": 1, + } + time, img_out = inference_func(test_example) + img_out = img_out[0] + + assert time < get_time_bound() + assert img_out.size == (1024, 1024) + + +def test_num_outputs(inference_func): + """num_outputs = 4, assert time is about what you'd expect off of the prediction object""" + base_time = None + for n_outputs in range(1, 5): + test_example = { + "prompt": "A cool dog", + "aspect_ratio": "1:1", + "num_outputs": n_outputs, + } + time, img_out = inference_func(test_example) + assert len(img_out) == n_outputs + if base_time: + assert time < base_time * n_outputs * 1.5 + if n_outputs == 1: + base_time = time + + +def test_determinism(inference_func): + """determinism - test with the same seed twice""" + test_example = {"prompt": "A cool dog", "aspect_ratio": "9:16", "num_outputs": 1, "seed": 112358} + time, out_one = inference_func(test_example) + out_one = out_one[0] + assert time < get_time_bound() + time_two, out_two = inference_func(test_example) + out_two = out_two[0] + assert time_two < get_time_bound() + assert out_one.size == (768, 1344) + + one_array = np.array(out_one, dtype=np.uint16) + two_array = np.array(out_two, dtype=np.uint16) + assert np.allclose(one_array, two_array, atol=20) + + +def test_resolutions(inference_func): + """changing resolutions - iterate through all resolutions and make sure that the output is the expected size""" + aspect_ratios = { + "1:1": (1024, 1024), + "16:9": (1344, 768), + "21:9": (1536, 640), + "3:2": (1216, 832), + "2:3": (832, 1216), + "4:5": (896, 1088), + "5:4": (1088, 896), + "9:16": (768, 1344), + "9:21": (640, 1536), + } + + for ratio, output in aspect_ratios.items(): + test_example = {"prompt": "A cool dog", "aspect_ratio": ratio, "num_outputs": 1, "seed": 112358} + + time, img_out = inference_func(test_example) + img_out = img_out[0] + assert img_out.size == output + assert time < get_time_bound() + + +def test_img2img(inference_func): + """img2img.
does it work?""" + if not IS_DEV: + assert True + return + + test_example = { + "prompt": "a cool walrus", + "image": "https://replicate.delivery/pbxt/IS6z50uYJFdFeh1vCmXe9zasYbG16HqOOMETljyUJ1hmlUXU/keanu.jpeg", + } + + _, img_out = inference_func(test_example) + img_out = img_out[0] + assert img_out.size[0] % 16 == 0 + assert img_out.size[0] < 1440 + assert img_out.size[1] % 16 == 0 + assert img_out.size[1] < 1440 diff --git a/predict.py b/predict.py index 049726f..0b9002f 100644 --- a/predict.py +++ b/predict.py @@ -1,5 +1,4 @@ import os -import pickle import time from typing import Optional @@ -8,7 +7,6 @@ import torch import numpy as np -from einops import rearrange from PIL import Image from typing import List from einops import rearrange @@ -26,33 +24,37 @@ SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar" MAX_IMAGE_SIZE = 1440 + @dataclass class SharedInputs: prompt: Input = Input(description="Prompt for generated image") aspect_ratio: Input = Input( - description="Aspect ratio for the generated image", - choices=["1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"], - default="1:1") + description="Aspect ratio for the generated image", + choices=["1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"], + default="1:1", + ) num_outputs: Input = Input(description="Number of outputs to generate", default=1, le=4, ge=1) seed: Input = Input(description="Random seed. Set for reproducible generation", default=None) output_format: Input = Input( - description="Format of the output images", - choices=["webp", "jpg", "png"], - default="webp", - ) + description="Format of the output images", + choices=["webp", "jpg", "png"], + default="webp", + ) output_quality: Input = Input( - description="Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs", - default=80, - ge=0, - le=100, - ) + description="Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs", + default=80, + ge=0, + le=100, + ) disable_safety_checker: Input = Input( - description="Disable safety checker for generated images.", - default=False, + description="Disable safety checker for generated images.", + default=False, ) + SHARED_INPUTS = SharedInputs() + class Predictor(BasePredictor): def setup(self) -> None: return @@ -97,13 +99,12 @@ def base_setup(self, flow_model_name: str, compile: bool) -> None: num_outputs=1, num_inference_steps=self.num_steps, guidance=3.5, - output_format='png', + output_format="png", output_quality=80, disable_safety_checker=True, - seed=123 + seed=123, ) - def aspect_ratio_to_width_height(self, aspect_ratio: str): aspect_ratios = { "1:1": (1024, 1024), @@ -117,7 +118,7 @@ def aspect_ratio_to_width_height(self, aspect_ratio: str): "9:21": (640, 1536), } return aspect_ratios.get(aspect_ratio) - + def get_image(self, image: str): if image is None: return None @@ -130,7 +131,7 @@ def get_image(self, image: str): ) img: torch.Tensor = transform(image) return img[None, ...] 
- + def predict(): raise Exception("You need to instantiate a predictor for a specific flux model") @@ -143,8 +144,8 @@ def base_predict( output_quality: int, disable_safety_checker: bool, num_inference_steps: int, - guidance: float = 3.5, # schnell ignores guidance within the model, fine to have default - image: Path = None, # img2img for flux-dev + guidance: float = 3.5, # schnell ignores guidance within the model, fine to have default + image: Path = None, # img2img for flux-dev prompt_strength: float = 0.8, seed: Optional[int] = None, ) -> List[Path]: @@ -199,7 +200,7 @@ def base_predict( if self.offload: self.t5, self.clip = self.t5.to(torch_device), self.clip.to(torch_device) - inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=[prompt]*num_outputs) + inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=[prompt] * num_outputs) if self.offload: self.t5, self.clip = self.t5.cpu(), self.clip.cpu() @@ -210,7 +211,9 @@ def base_predict( print("Compiling") st = time.time() - x, flux = denoise(self.flux, **inp, timesteps=timesteps, guidance=guidance, compile_run=self.compile_run) + x, flux = denoise( + self.flux, **inp, timesteps=timesteps, guidance=guidance, compile_run=self.compile_run + ) if self.compile_run: print(f"Compiled in {time.time() - st}") @@ -221,7 +224,7 @@ def base_predict( self.flux.cpu() torch.cuda.empty_cache() self.ae.decoder.to(x.device) - + x = unpack(x.float(), height, width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = self.ae.decode(x) @@ -229,12 +232,17 @@ def base_predict( if self.offload: self.ae.decoder.cpu() torch.cuda.empty_cache() - - images = [Image.fromarray((127.5 * (rearrange(x[i], "c h w -> h w c").clamp(-1, 1) + 1.0)).cpu().byte().numpy()) for i in range(num_outputs)] + + images = [ + Image.fromarray( + (127.5 * (rearrange(x[i], "c h w -> h w c").clamp(-1, 1) + 1.0)).cpu().byte().numpy() + ) + for i in range(num_outputs) + ] has_nsfw_content = [False] * len(images) if not disable_safety_checker: - _, has_nsfw_content = self.run_safety_checker(images) # always on gpu - + _, has_nsfw_content = self.run_safety_checker(images) # always on gpu + output_paths = [] for i, (img, is_nsfw) in enumerate(zip(images, has_nsfw_content)): if is_nsfw: @@ -242,16 +250,18 @@ def base_predict( continue output_path = f"out-{i}.{output_format}" - save_params = {'quality': output_quality, 'optimize': True} if output_format != 'png' else {} + save_params = {"quality": output_quality, "optimize": True} if output_format != "png" else {} img.save(output_path, **save_params) output_paths.append(Path(output_path)) if not output_paths: - raise Exception("All generated images contained NSFW content. Try running it again with a different prompt.") + raise Exception( + "All generated images contained NSFW content. Try running it again with a different prompt." 
+ ) print(f"Total safe images: {len(output_paths)} out of {len(images)}") return output_paths - + def run_safety_checker(self, images): safety_checker_input = self.feature_extractor(images, return_tensors="pt").to("cuda") np_images = [np.array(img) for img in images] @@ -261,10 +271,11 @@ def run_safety_checker(self, images): ) return image, has_nsfw_concept + class SchnellPredictor(Predictor): def setup(self) -> None: self.base_setup("flux-schnell", compile=False) - + @torch.inference_mode() def predict( self, @@ -276,30 +287,57 @@ def predict( output_quality: int = SHARED_INPUTS.output_quality, disable_safety_checker: bool = SHARED_INPUTS.disable_safety_checker, ) -> List[Path]: + return self.base_predict( + prompt, + aspect_ratio, + num_outputs, + output_format, + output_quality, + disable_safety_checker, + num_inference_steps=self.num_steps, + seed=seed, + ) - return self.base_predict(prompt, aspect_ratio, num_outputs, output_format, output_quality, disable_safety_checker, num_inference_steps=self.num_steps, seed=seed) - class DevPredictor(Predictor): def setup(self) -> None: self.base_setup("flux-dev", compile=True) - + @torch.inference_mode() def predict( self, prompt: str = SHARED_INPUTS.prompt, aspect_ratio: str = SHARED_INPUTS.aspect_ratio, - image: Path = Input(description="Input image for image to image mode. The aspect ratio of your output will match this image", default=None), - prompt_strength: float = Input(description="Prompt strength when using img2img. 1.0 corresponds to full destruction of information in image", - ge=0.0, le=1.0, default=0.80, + image: Path = Input( + description="Input image for image to image mode. The aspect ratio of your output will match this image", + default=None, + ), + prompt_strength: float = Input( + description="Prompt strength when using img2img. 1.0 corresponds to full destruction of information in image", + ge=0.0, + le=1.0, + default=0.80, ), num_outputs: int = SHARED_INPUTS.num_outputs, - num_inference_steps: int = Input(description="Number of denoising steps. Recommended range is 28-50", ge=1, le=50, default=28), + num_inference_steps: int = Input( + description="Number of denoising steps. 
Recommended range is 28-50", ge=1, le=50, default=28 + ), guidance: float = Input(description="Guidance for generated image", ge=0, le=10, default=3), seed: int = SHARED_INPUTS.seed, output_format: str = SHARED_INPUTS.output_format, output_quality: int = SHARED_INPUTS.output_quality, disable_safety_checker: bool = SHARED_INPUTS.disable_safety_checker, ) -> List[Path]: - - return self.base_predict(prompt, aspect_ratio, num_outputs, output_format, output_quality, disable_safety_checker, guidance=guidance, image=image, prompt_strength=prompt_strength, num_inference_steps=num_inference_steps, seed=seed) + return self.base_predict( + prompt, + aspect_ratio, + num_outputs, + output_format, + output_quality, + disable_safety_checker, + guidance=guidance, + image=image, + prompt_strength=prompt_strength, + num_inference_steps=num_inference_steps, + seed=seed, + ) diff --git a/requirements_test.txt b/requirements_test.txt new file mode 100644 index 0000000..3184d45 --- /dev/null +++ b/requirements_test.txt @@ -0,0 +1,5 @@ +numpy +pytest +replicate +requests +Pillow \ No newline at end of file diff --git a/script/test.sh b/script/test.sh new file mode 100644 index 0000000..bb8efb7 --- /dev/null +++ b/script/test.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +if [ $# -ne 2 ]; then + echo "Usage: $0 <model_name> <environment>" + echo "Environment should be either 'test' or 'prod'" + echo "Model name should be either 'dev' or 'schnell'" + exit 1 +fi + +MODEL_NAME="$1" +ENVIRONMENT="$2" + +# Validate environment argument +if [ "$ENVIRONMENT" != "test" ] && [ "$ENVIRONMENT" != "prod" ]; then + echo "Invalid environment. Please use 'test' or 'prod'." + exit 1 +fi + +# Select the full model name to test based on environment +if [ "$ENVIRONMENT" == "test" ]; then + MODEL_NAME="replicate-internal/flux-$MODEL_NAME" +elif [ "$ENVIRONMENT" == "prod" ]; then + MODEL_NAME="replicate/flux-$MODEL_NAME-internal-model" +fi + +# Expose the target model and environment to the pytest suite +export MODEL_NAME ENVIRONMENT + +echo "Running tests on $MODEL_NAME" + +pytest -vv integration-tests/ diff --git a/torch_compile.py b/torch_compile.py index d141e03..37171a6 100644 --- a/torch_compile.py +++ b/torch_compile.py @@ -1,4 +1,3 @@ - # Imports flux-schnell model from predict.py # from flux.util import load_flow_model # flux = load_flow_model("flux-schnell", device="cuda") @@ -56,4 +55,3 @@ # 1440:1440 # torch.Size([1, 16, 180, 180]) # torch.Size([1, 8100, 64]) -_ \ No newline at end of file
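Note on the compile path: the shape notes in torch_compile.py map each resolution to a packed token count (e.g. 4096 tokens at 1024x1024, 8100 at 1440x1440), and denoise_single_item in flux/sampling.py marks that token dimension dynamic before calling torch.compile so one compiled graph covers the whole range instead of recompiling per aspect ratio. A minimal sketch of the pattern, assuming torch 2.4+ for the min/max bounds; the toy module and shapes below are illustrative stand-ins, not code from this repo:

import torch
from torch import nn


class ToyTransformer(nn.Module):
    # Stand-in for the Flux transformer: anything whose graph sees the token dim.
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return img * 2.0


model = ToyTransformer()


def denoise_step(img: torch.Tensor, compile_run: bool) -> torch.Tensor:
    global model
    if compile_run:
        # Declare dim 1 (packed image tokens) dynamic with bounds, so the first
        # trace yields a graph that is reused for every supported resolution.
        torch._dynamo.mark_dynamic(img, 1, min=256, max=8100)
        model = torch.compile(model)
    return model(img)


out = denoise_step(torch.randn(1, 4096, 64), compile_run=True)   # 1024x1024 -> 4096 tokens, compiles once
out = denoise_step(torch.randn(1, 8100, 64), compile_run=False)  # 1440x1440 -> 8100 tokens, no recompile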