diff --git a/inference/core/version.py b/inference/core/version.py index 00d2732ad..4a5f53a40 100644 --- a/inference/core/version.py +++ b/inference/core/version.py @@ -1,4 +1,4 @@ -__version__ = "0.30.0" +__version__ = "0.31.0" if __name__ == "__main__": diff --git a/inference/core/workflows/core_steps/loader.py b/inference/core/workflows/core_steps/loader.py index d3b1296b9..4624dcc75 100644 --- a/inference/core/workflows/core_steps/loader.py +++ b/inference/core/workflows/core_steps/loader.py @@ -141,9 +141,15 @@ from inference.core.workflows.core_steps.fusion.dimension_collapse.v1 import ( DimensionCollapseBlockV1, ) +from inference.core.workflows.core_steps.math.cosine_similarity.v1 import ( + CosineSimilarityBlockV1, +) from inference.core.workflows.core_steps.models.foundation.anthropic_claude.v1 import ( AnthropicClaudeBlockV1, ) +from inference.core.workflows.core_steps.models.foundation.clip.v1 import ( + ClipModelBlockV1, +) from inference.core.workflows.core_steps.models.foundation.clip_comparison.v1 import ( ClipComparisonBlockV1, ) @@ -479,6 +485,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]: DimensionCollapseBlockV1, FirstNonEmptyOrDefaultBlockV1, AnthropicClaudeBlockV1, + CosineSimilarityBlockV1, BackgroundColorVisualizationBlockV1, BarcodeDetectorBlockV1, BlurVisualizationBlockV1, @@ -489,6 +496,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]: CircleVisualizationBlockV1, ClipComparisonBlockV1, ClipComparisonBlockV2, + ClipModelBlockV1, CogVLMBlockV1, ColorVisualizationBlockV1, ConvertGrayscaleBlockV1, diff --git a/inference/core/workflows/core_steps/math/cosine_similarity/__init__.py b/inference/core/workflows/core_steps/math/cosine_similarity/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/inference/core/workflows/core_steps/math/cosine_similarity/v1.py b/inference/core/workflows/core_steps/math/cosine_similarity/v1.py new file mode 100644 index 000000000..fcf5892b3 --- /dev/null +++ b/inference/core/workflows/core_steps/math/cosine_similarity/v1.py @@ -0,0 +1,75 @@ +from typing import List, Literal, Optional, Type + +from pydantic import ConfigDict, Field + +from inference.core.utils.postprocess import cosine_similarity +from inference.core.workflows.execution_engine.entities.base import OutputDefinition +from inference.core.workflows.execution_engine.entities.types import ( + EMBEDDING_KIND, + FLOAT_KIND, + Selector, +) +from inference.core.workflows.prototypes.block import ( + BlockResult, + WorkflowBlock, + WorkflowBlockManifest, +) + +LONG_DESCRIPTION = """ +Calculate the cosine similarity between two embeddings. + +A cosine similarity of 1 means the two embeddings are identical, +while a cosine similarity of 0 means the two embeddings are orthogonal. +Greater values indicate greater similarity. 
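+A cosine similarity of -1 means the two embeddings point in opposite directions.
+
+For reference, the value is the dot product of the two embeddings divided by the
+product of their magnitudes:
+
+    similarity = (A · B) / (||A|| * ||B||)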
+""" + + +class BlockManifest(WorkflowBlockManifest): + model_config = ConfigDict( + json_schema_extra={ + "name": "Cosine Similarity", + "version": "v1", + "short_description": "Calculate the cosine similarity between two embeddings.", + "long_description": LONG_DESCRIPTION, + "license": "MIT", + "block_type": "math", + "ui_manifest": { + "section": "advanced", + "icon": "far fa-calculator-simple", + "blockPriority": 3, + }, + } + ) + type: Literal["roboflow_core/cosine_similarity@v1"] + name: str = Field(description="Unique name of step in workflows") + embedding_1: Selector(kind=[EMBEDDING_KIND]) = Field( + description="Embedding 1", + examples=["$steps.clip_image.embedding"], + ) + embedding_2: Selector(kind=[EMBEDDING_KIND]) = Field( + description="Embedding 2", + examples=["$steps.clip_text.embedding"], + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [OutputDefinition(name="similarity", kind=[FLOAT_KIND])] + + @classmethod + def get_execution_engine_compatibility(cls) -> Optional[str]: + return ">=1.3.0,<2.0.0" + + +class CosineSimilarityBlockV1(WorkflowBlock): + @classmethod + def get_manifest(cls) -> Type[WorkflowBlockManifest]: + return BlockManifest + + def run(self, embedding_1: List[float], embedding_2: List[float]) -> BlockResult: + if len(embedding_1) != len(embedding_2): + raise RuntimeError( + f"roboflow_core/cosine_similarity@v1 block feed with different shape of embeddings. " + f"`embedding_1`: (N, {len(embedding_1)}), `embedding_2`: (N, {len(embedding_2)})" + ) + similarity = cosine_similarity(embedding_1, embedding_2) + return {"similarity": similarity} diff --git a/inference/core/workflows/core_steps/models/foundation/clip/__init__.py b/inference/core/workflows/core_steps/models/foundation/clip/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/inference/core/workflows/core_steps/models/foundation/clip/v1.py b/inference/core/workflows/core_steps/models/foundation/clip/v1.py new file mode 100644 index 000000000..9e92b1ce5 --- /dev/null +++ b/inference/core/workflows/core_steps/models/foundation/clip/v1.py @@ -0,0 +1,199 @@ +from functools import partial +from typing import List, Literal, Optional, Type, Union + +from pydantic import ConfigDict, Field + +from inference.core.entities.requests.clip import ( + ClipImageEmbeddingRequest, + ClipTextEmbeddingRequest, +) +from inference.core.env import ( + HOSTED_CORE_MODEL_URL, + LOCAL_INFERENCE_API_URL, + WORKFLOWS_REMOTE_API_TARGET, + WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, +) +from inference.core.managers.base import ModelManager +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.common.utils import ( + load_core_model, + run_in_parallel, +) +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + OutputDefinition, + WorkflowImageData, +) +from inference.core.workflows.execution_engine.entities.types import ( + EMBEDDING_KIND, + IMAGE_KIND, + STRING_KIND, + Selector, +) +from inference.core.workflows.prototypes.block import ( + BlockResult, + WorkflowBlock, + WorkflowBlockManifest, +) +from inference_sdk import InferenceHTTPClient + +LONG_DESCRIPTION = """ +Use a CLIP model to create semantic embeddings of text and images. + +This block accepts an image or string and returns an embedding. +The embedding can be used to compare the similarity between different +images or between images and text. 
+""" + + +class BlockManifest(WorkflowBlockManifest): + model_config = ConfigDict( + json_schema_extra={ + "name": "CLIP Embedding Model", + "version": "v1", + "short_description": "Generate an embedding of an image or string.", + "long_description": LONG_DESCRIPTION, + "license": "MIT", + "block_type": "model", + "ui_manifest": { + "section": "model", + "icon": "far fa-paperclip", + "blockPriority": 2, + }, + } + ) + type: Literal["roboflow_core/clip@v1"] + name: str = Field(description="Unique name of step in workflows") + data: Union[Selector(kind=[IMAGE_KIND, STRING_KIND]), str] = Field( + title="Data", + description="The string or image to generate an embedding for.", + examples=["$inputs.image", "$steps.cropping.crops"], + ) + + version: Union[ + Literal[ + "RN101", + "RN50", + "RN50x16", + "RN50x4", + "RN50x64", + "ViT-B-16", + "ViT-B-32", + "ViT-L-14-336px", + "ViT-L-14", + ], + Selector(kind=[STRING_KIND]), + ] = Field( + default="ViT-B-32", + description="Variant of CLIP model", + examples=["ViT-B-16", "$inputs.variant"], + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [OutputDefinition(name="embedding", kind=[EMBEDDING_KIND])] + + @classmethod + def get_execution_engine_compatibility(cls) -> Optional[str]: + return ">=1.3.0,<2.0.0" + + +class ClipModelBlockV1(WorkflowBlock): + + def __init__( + self, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, + ): + self._model_manager = model_manager + self._api_key = api_key + self._step_execution_mode = step_execution_mode + + @classmethod + def get_init_parameters(cls) -> List[str]: + return ["model_manager", "api_key", "step_execution_mode"] + + @classmethod + def get_manifest(cls) -> Type[WorkflowBlockManifest]: + return BlockManifest + + def run( + self, + data: Union[WorkflowImageData, str], + version: str, + ) -> BlockResult: + if self._step_execution_mode is StepExecutionMode.LOCAL: + return self.run_locally(data=data, version=version) + elif self._step_execution_mode is StepExecutionMode.REMOTE: + return self.run_remotely(data=data, version=version) + else: + raise ValueError( + f"Unknown step execution mode: {self._step_execution_mode}" + ) + + def run_locally( + self, + data: Union[WorkflowImageData, str], + version: str, + ) -> BlockResult: + if isinstance(data, str): + inference_request = ClipTextEmbeddingRequest( + clip_version_id=version, + text=[data], + api_key=self._api_key, + ) + clip_model_id = load_core_model( + model_manager=self._model_manager, + inference_request=inference_request, + core_model="clip", + ) + predictions = self._model_manager.infer_from_request_sync( + clip_model_id, inference_request + ) + return {"embedding": predictions.embeddings[0]} + else: + inference_request = ClipImageEmbeddingRequest( + clip_version_id=version, + image=[data.to_inference_format(numpy_preferred=True)], + api_key=self._api_key, + ) + clip_model_id = load_core_model( + model_manager=self._model_manager, + inference_request=inference_request, + core_model="clip", + ) + predictions = self._model_manager.infer_from_request_sync( + clip_model_id, inference_request + ) + return {"embedding": predictions.embeddings[0]} + + def run_remotely( + self, + data: Union[WorkflowImageData, str], + version: str, + ) -> BlockResult: + api_url = ( + LOCAL_INFERENCE_API_URL + if WORKFLOWS_REMOTE_API_TARGET != "hosted" + else HOSTED_CORE_MODEL_URL + ) + client = InferenceHTTPClient( + api_url=api_url, + api_key=self._api_key, + ) + if 
WORKFLOWS_REMOTE_API_TARGET == "hosted": + client.select_api_v0() + + if isinstance(data, str): + result = client.get_clip_text_embeddings( + text=data, + clip_version=version, + ) + else: + result = client.get_clip_image_embeddings( + inference_input=data.base64_image, + clip_version=version, + ) + + return {"embedding": result["embeddings"][0]} diff --git a/inference/core/workflows/execution_engine/entities/types.py b/inference/core/workflows/execution_engine/entities/types.py index d70948857..c51e99147 100644 --- a/inference/core/workflows/execution_engine/entities/types.py +++ b/inference/core/workflows/execution_engine/entities/types.py @@ -210,6 +210,22 @@ def __hash__(self) -> int: internal_data_type="List[Any]", ) +EMBEDDING_KIND_DOCS = """ +This kind represents a vector embedding. It is a list of floating point numbers. + +Embeddings are used in various machine learning tasks like clustering, classification, +and similarity search. They are used to represent data in a continuous, low-dimensional space. + +Typically, vectors that are close to each other in the embedding space are considered similar. +""" +EMBEDDING_KIND = Kind( + name="embedding", + description="A list of floating point numbers representing a vector embedding.", + docs=EMBEDDING_KIND_DOCS, + serialised_data_type="List[float]", + internal_data_type="List[float]", +) + RGB_COLOR_KIND_DOCS = """ This kind represents RGB color as a tuple (R, G, B). diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_clip.py b/tests/workflows/integration_tests/execution/test_workflow_with_clip.py index 944f50d0c..0ffd9b150 100644 --- a/tests/workflows/integration_tests/execution/test_workflow_with_clip.py +++ b/tests/workflows/integration_tests/execution/test_workflow_with_clip.py @@ -4,13 +4,339 @@ from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS from inference.core.managers.base import ModelManager from inference.core.workflows.core_steps.common.entities import StepExecutionMode -from inference.core.workflows.errors import RuntimeInputError +from inference.core.workflows.errors import RuntimeInputError, StepExecutionError from inference.core.workflows.execution_engine.core import ExecutionEngine from tests.workflows.integration_tests.execution.workflows_gallery_collector.decorators import ( add_to_workflows_gallery, ) CLIP_WORKFLOW = { + "version": "1.0", + "inputs": [ + {"type": "InferenceImage", "name": "image_1"}, + {"type": "InferenceImage", "name": "image_2"}, + ], + "steps": [ + { + "type": "roboflow_core/clip@v1", + "name": "embedding_1", + "data": "$inputs.image_1", + "version": "RN50", + }, + { + "type": "roboflow_core/clip@v1", + "name": "embedding_2", + "data": "$inputs.image_2", + "version": "RN50", + }, + { + "type": "roboflow_core/cosine_similarity@v1", + "name": "cosine_similarity", + "embedding_1": "$steps.embedding_1.embedding", + "embedding_2": "$steps.embedding_2.embedding", + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "similarity", + "coordinates_system": "own", + "selector": "$steps.cosine_similarity.similarity", + }, + { + "type": "JsonField", + "name": "image_embeddings", + "coordinates_system": "own", + "selector": "$steps.embedding_1.embedding", + }, + ], +} + + +@add_to_workflows_gallery( + category="Basic Workflows", + use_case_title="Workflow with Embeddings", + use_case_description=""" +This Workflow shows how to use an embedding model to compare the +similarity of two images with each other. 
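+The cosine similarity output ranges from -1 to 1, with values closer to 1
+indicating that the two images are more similar.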
+ """, + workflow_definition=CLIP_WORKFLOW, + workflow_name_in_app="clip", +) +def test_clip_embedding_model( + model_manager: ModelManager, + license_plate_image: np.ndarray, + crowd_image: np.ndarray, +) -> None: + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.api_key": None, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=CLIP_WORKFLOW, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + result = execution_engine.run( + runtime_parameters={"image_1": license_plate_image, "image_2": crowd_image} + ) + + # then + assert isinstance(result, list), "Expected list to be delivered" + assert len(result) == 1, "Expected 1 element in the output" + assert set(result[0].keys()) == { + "similarity", + "image_embeddings", + }, "Expected all declared outputs to be delivered" + assert ( + pytest.approx(result[0]["similarity"], 0.01) == 0.444 + ), "Expected similarity to be approximately the defined value" + assert ( + len(result[0]["image_embeddings"]) == 1024 + ), "Expected image embedding to be of dimension 1024 for RN50 model" + + +CLIP_WORKFLOW_COSINE_SIMILARITY_CROSS_DATA_TYPE = { + "version": "1.0", + "inputs": [ + {"type": "InferenceImage", "name": "image_1"}, + {"type": "WorkflowParameter", "name": "reference"}, + ], + "steps": [ + { + "type": "roboflow_core/clip@v1", + "name": "embedding_1", + "data": "$inputs.image_1", + "version": "RN50", + }, + { + "type": "roboflow_core/clip@v1", + "name": "embedding_2", + "data": "$inputs.reference", + "version": "RN50", + }, + { + "type": "roboflow_core/cosine_similarity@v1", + "name": "cosine_similarity", + "embedding_1": "$steps.embedding_1.embedding", + "embedding_2": "$steps.embedding_2.embedding", + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "similarity", + "coordinates_system": "own", + "selector": "$steps.cosine_similarity.similarity", + }, + { + "type": "JsonField", + "name": "image_embeddings", + "coordinates_system": "own", + "selector": "$steps.embedding_1.embedding", + }, + ], +} + + +def test_clip_embedding_model_on_batches_of_cross_type_data( + model_manager: ModelManager, + license_plate_image: np.ndarray, + crowd_image: np.ndarray, +) -> None: + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.api_key": None, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=CLIP_WORKFLOW_COSINE_SIMILARITY_CROSS_DATA_TYPE, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + result = execution_engine.run( + runtime_parameters={ + "image_1": [license_plate_image, crowd_image], + "reference": "people", + } + ) + + # then + assert isinstance(result, list), "Expected list to be delivered" + assert len(result) == 2, "Expected 2 elements in the output" + assert set(result[0].keys()) == { + "similarity", + "image_embeddings", + }, "Expected all declared outputs to be delivered" + assert set(result[1].keys()) == { + "similarity", + "image_embeddings", + }, "Expected all declared outputs to be delivered" + assert ( + abs(result[0]["similarity"] - 0.13) < 0.02 + ), "Expected similarity to be approximately the defined value" + assert ( + len(result[0]["image_embeddings"]) == 1024 + ), "Expected image embedding to be of dimension 1024 for RN50 
model" + assert ( + abs(result[1]["similarity"] - 0.15) < 0.02 + ), "Expected similarity to be approximately the defined value" + assert ( + len(result[1]["image_embeddings"]) == 1024 + ), "Expected image embedding to be of dimension 1024 for RN50 model" + + +CLIP_WORKFLOW_COSINE_SIMILARITY_CROSS_DATA_TYPE_WITH_INVALID_LENGTH_OF_EMBEDDINGS = { + "version": "1.0", + "inputs": [ + {"type": "InferenceImage", "name": "image_1"}, + {"type": "WorkflowParameter", "name": "reference"}, + ], + "steps": [ + { + "type": "roboflow_core/clip@v1", + "name": "embedding_1", + "data": "$inputs.image_1", + "version": "RN50", + }, + { + "type": "roboflow_core/clip@v1", + "name": "embedding_2", + "data": "$inputs.reference", + "version": "RN50x4", + }, + { + "type": "roboflow_core/cosine_similarity@v1", + "name": "cosine_similarity", + "embedding_1": "$steps.embedding_1.embedding", + "embedding_2": "$steps.embedding_2.embedding", + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "similarity", + "coordinates_system": "own", + "selector": "$steps.cosine_similarity.similarity", + }, + { + "type": "JsonField", + "name": "image_embeddings", + "coordinates_system": "own", + "selector": "$steps.embedding_1.embedding", + }, + ], +} + + +def test_clip_embedding_model_on_batches_of_cross_type_data_with_different_embeddings_length( + model_manager: ModelManager, + license_plate_image: np.ndarray, + crowd_image: np.ndarray, +) -> None: + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.api_key": None, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=CLIP_WORKFLOW_COSINE_SIMILARITY_CROSS_DATA_TYPE_WITH_INVALID_LENGTH_OF_EMBEDDINGS, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + with pytest.raises(StepExecutionError) as error: + _ = execution_engine.run( + runtime_parameters={ + "image_1": [license_plate_image, crowd_image], + "reference": "people", + } + ) + + # then + assert ( + "roboflow_core/cosine_similarity@v1 block feed with different shape of embeddings" + in str(error.value) + ) + + +CLIP_TEXT_WORKFLOW = { + "version": "1.0", + "inputs": [ + {"type": "WorkflowParameter", "name": "prompt"}, + ], + "steps": [ + { + "type": "roboflow_core/clip@v1", + "name": "embedding", + "data": "$inputs.prompt", + "version": "RN50", + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "text_embeddings", + "coordinates_system": "own", + "selector": "$steps.embedding.embedding", + }, + ], +} + + +def test_clip_text_embedding_model( + model_manager: ModelManager, + license_plate_image: np.ndarray, + crowd_image: np.ndarray, +) -> None: + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.api_key": None, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=CLIP_TEXT_WORKFLOW, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + result = execution_engine.run(runtime_parameters={"prompt": "Foo Bar"}) + + # then + assert isinstance(result, list), "Expected list to be delivered" + assert len(result) == 1, "Expected 1 element in the output" + assert set(result[0].keys()) == { + "text_embeddings", + }, "Expected all declared outputs to be delivered" + assert ( + len(result[0]["text_embeddings"]) == 1024 + ), "Expected 
text embedding to be of dimension 1024 for RN50 model" + assert ( + pytest.approx(np.mean(result[0]["text_embeddings"]), 0.0001) == -0.016772 + ), "Expected embedding to have a value similar to during testing" + assert ( + pytest.approx(np.max(result[0]["text_embeddings"]), 0.0001) == 1.65736556 + ), "Expected embedding to have a value similar to during testing" + assert ( + pytest.approx(np.min(result[0]["text_embeddings"]), 0.0001) == -10.109556 + ), "Expected embedding to have a value similar to during testing" + assert ( + pytest.approx(np.std(result[0]["text_embeddings"]), 0.0001) == 0.39733439 + ), "Expected embedding to have a value similar to during testing" + + +CLIP_COMPARISON_WORKFLOW = { "version": "1.0", "inputs": [ {"type": "WorkflowImage", "name": "image"}, @@ -36,16 +362,16 @@ @add_to_workflows_gallery( category="Basic Workflows", - use_case_title="Workflow with CLIP model", + use_case_title="Workflow with CLIP Comparison", use_case_description=""" -This is the basic workflow that only contains a single CLIP model block. +This is the basic workflow that only contains a single CLIP Comparison block. Please take a look at how batch-oriented WorkflowImage data is plugged to detection step via input selector (`$inputs.image`) and how non-batch parameters (reference set of texts that the each image in batch will be compared to) is dynamically specified - via `$inputs.reference` selector. """, - workflow_definition=CLIP_WORKFLOW, + workflow_definition=CLIP_COMPARISON_WORKFLOW, workflow_name_in_app="clip", ) def test_clip_workflow_when_minimal_valid_input_provided( @@ -60,7 +386,7 @@ def test_clip_workflow_when_minimal_valid_input_provided( "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, } execution_engine = ExecutionEngine.init( - workflow_definition=CLIP_WORKFLOW, + workflow_definition=CLIP_COMPARISON_WORKFLOW, init_parameters=workflow_init_parameters, max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, ) diff --git a/tests/workflows/unit_tests/core_steps/math/test_cosine_similarity.py b/tests/workflows/unit_tests/core_steps/math/test_cosine_similarity.py new file mode 100644 index 000000000..8e9d780e9 --- /dev/null +++ b/tests/workflows/unit_tests/core_steps/math/test_cosine_similarity.py @@ -0,0 +1,82 @@ +import pytest +from pydantic import ValidationError + +from inference.core.workflows.core_steps.math.cosine_similarity.v1 import ( + BlockManifest, + CosineSimilarityBlockV1, +) + + +def test_manifest_parsing_when_data_is_valid(): + # Given + data = { + "type": "roboflow_core/cosine_similarity@v1", + "name": "cosine_step", + "embedding_1": "$steps.clip_image.embedding", + "embedding_2": "$steps.clip_text.embedding", + } + + # When + result = BlockManifest.model_validate(data) + + # Then + assert result.type == "roboflow_core/cosine_similarity@v1" + assert result.name == "cosine_step" + assert result.embedding_1 == "$steps.clip_image.embedding" + assert result.embedding_2 == "$steps.clip_text.embedding" + + +def test_manifest_parsing_when_data_is_invalid(): + # Given invalid data (not a valid embedding selector) + data = { + "type": "roboflow_core/cosine_similarity@v1", + "name": "cosine_step", + "embedding_1": "invalid_data", + "embedding_2": "invalid_data", + } + + # When / Then + with pytest.raises(ValidationError): + BlockManifest.model_validate(data) + + +def test_cosine_similarity_block_run_identical_embeddings(): + # Given identical embeddings + block = CosineSimilarityBlockV1() + embedding_1 = [0.1, 0.3, 0.5] + embedding_2 = [0.1, 0.3, 0.5] + + # When + 
result = block.run(embedding_1=embedding_1, embedding_2=embedding_2) + + # Then + # Cosine similarity should be close to 1.0 for identical vectors + assert pytest.approx(result["similarity"], 0.0001) == 1.0 + + +def test_cosine_similarity_block_run_orthogonal_embeddings(): + # Given orthogonal embeddings + block = CosineSimilarityBlockV1() + embedding_1 = [1.0, 0.0, 0.0] + embedding_2 = [0.0, 1.0, 0.0] + + # When + result = block.run(embedding_1=embedding_1, embedding_2=embedding_2) + + # Then + # Cosine similarity should be close to 0.0 for orthogonal vectors + assert pytest.approx(result["similarity"], 0.0001) == 0.0 + + +def test_cosine_similarity_block_run_negative_correlation(): + # Given inversely correlated embeddings + block = CosineSimilarityBlockV1() + embedding_1 = [1.0, 1.0, 1.0] + embedding_2 = [-1.0, -1.0, -1.0] + + # When + result = block.run(embedding_1=embedding_1, embedding_2=embedding_2) + + # Then + # Cosine similarity should be close to -1.0 for perfectly negatively correlated vectors + assert pytest.approx(result["similarity"], 0.0001) == -1.0 diff --git a/tests/workflows/unit_tests/core_steps/models/foundation/test_clip.py b/tests/workflows/unit_tests/core_steps/models/foundation/test_clip.py new file mode 100644 index 000000000..6ee576a7f --- /dev/null +++ b/tests/workflows/unit_tests/core_steps/models/foundation/test_clip.py @@ -0,0 +1,149 @@ +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +from pydantic import ValidationError + +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.models.foundation.clip.v1 import ( + BlockManifest, + ClipModelBlockV1, +) +from inference.core.workflows.execution_engine.entities.base import ( + ImageParentMetadata, + WorkflowImageData, +) + + +@pytest.fixture +def mock_model_manager(): + # Mock a model manager that returns a predictable embedding + mock = MagicMock() + mock.infer_from_request_sync.return_value = MagicMock( + embeddings=[[0.1, 0.2, 0.3]] # Sample embedding + ) + return mock + + +@pytest.fixture +def mock_workflow_image_data(): + # Create a mock WorkflowImageData instance + start_image = np.random.randint(0, 255, (1000, 1000, 3), dtype=np.uint8) + return WorkflowImageData( + parent_metadata=ImageParentMetadata(parent_id="some"), + numpy_image=start_image, + ) + + +def test_manifest_parsing_valid(): + data = { + "type": "roboflow_core/clip@v1", + "name": "my_clip_step", + "data": "$inputs.image", + "version": "RN50", + } + + result = BlockManifest.model_validate(data) + assert result.type == "roboflow_core/clip@v1" + assert result.name == "my_clip_step" + assert result.data == "$inputs.image" + assert result.version == "RN50" + + +def test_manifest_parsing_invalid_missing_type(): + data = { + "name": "my_clip_step", + "data": "$inputs.image", + "version": "RN50", + } + with pytest.raises(ValidationError): + BlockManifest.model_validate(data) + + +def test_manifest_parsing_invalid_data_type(): + data = { + "type": "roboflow_core/clip@v1", + "name": "my_clip_step", + "data": 123, # invalid type + "version": "RN50", + } + with pytest.raises(ValidationError): + BlockManifest.model_validate(data) + + +def test_run_locally_with_text(mock_model_manager): + block = ClipModelBlockV1( + model_manager=mock_model_manager, + api_key=None, + step_execution_mode=StepExecutionMode.LOCAL, + ) + + # Run with text input + result = block.run(data="Hello world", version="RN50") + + assert isinstance(result, dict) + assert 
len(result["embedding"]) == 3 + assert result["embedding"] == [0.1, 0.2, 0.3] + mock_model_manager.infer_from_request_sync.assert_called_once() + + +def test_run_locally_with_image(mock_model_manager, mock_workflow_image_data): + block = ClipModelBlockV1( + model_manager=mock_model_manager, + api_key=None, + step_execution_mode=StepExecutionMode.LOCAL, + ) + + result = block.run(data=mock_workflow_image_data, version="RN50") + + assert isinstance(result, dict) + assert len(result["embedding"]) == 3 + assert result["embedding"] == [0.1, 0.2, 0.3] + mock_model_manager.infer_from_request_sync.assert_called_once() + + +@patch( + "inference.core.workflows.core_steps.models.foundation.clip.v1.InferenceHTTPClient" +) +def test_run_remotely_with_text(mock_client_cls, mock_model_manager): + # Mock the remote client and its return value + mock_client = MagicMock() + mock_client.get_clip_text_embeddings.return_value = { + "embeddings": [[0.1, 0.2, 0.3]] + } + mock_client_cls.return_value = mock_client + + block = ClipModelBlockV1( + model_manager=mock_model_manager, + api_key=None, + step_execution_mode=StepExecutionMode.REMOTE, + ) + + result = block.run(data="Hello world", version="RN50") + + assert result["embedding"] == [0.1, 0.2, 0.3] + mock_client.get_clip_text_embeddings.assert_called_once() + + +@patch( + "inference.core.workflows.core_steps.models.foundation.clip.v1.InferenceHTTPClient" +) +def test_run_remotely_with_image( + mock_client_cls, mock_model_manager, mock_workflow_image_data +): + mock_client = MagicMock() + mock_client.get_clip_image_embeddings.return_value = { + "embeddings": [[0.1, 0.2, 0.3]] + } + mock_client_cls.return_value = mock_client + + block = ClipModelBlockV1( + model_manager=mock_model_manager, + api_key=None, + step_execution_mode=StepExecutionMode.REMOTE, + ) + + result = block.run(data=mock_workflow_image_data, version="RN50") + + assert result["embedding"] == [0.1, 0.2, 0.3] + mock_client.get_clip_image_embeddings.assert_called_once()