-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #930 from roboflow/mqtt-writer
Added the MQTT Writer Block
- Loading branch information
Showing
7 changed files
with
344 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
170 changes: 170 additions & 0 deletions
170
inference/enterprise/workflows/enterprise_blocks/sinks/mqtt_writer/v1.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
from typing import List, Literal, Optional, Type, Union | ||
|
||
import paho.mqtt.client as mqtt | ||
from pydantic import ConfigDict, Field | ||
|
||
from inference.core.logger import logger | ||
from inference.core.workflows.execution_engine.entities.base import OutputDefinition | ||
from inference.core.workflows.execution_engine.entities.types import ( | ||
BOOLEAN_KIND, | ||
FLOAT_KIND, | ||
INTEGER_KIND, | ||
STRING_KIND, | ||
Selector, | ||
) | ||
from inference.core.workflows.prototypes.block import ( | ||
BlockResult, | ||
WorkflowBlock, | ||
WorkflowBlockManifest, | ||
) | ||
|
||
LONG_DESCRIPTION = """ | ||
MQTT Writer block for publishing messages to an MQTT broker. | ||
This block is blocking on connect and publish operations. | ||
Outputs: | ||
- error_status (bool): Indicates if an error occurred during the MQTT publishing process. | ||
True if there was an error, False if successful. | ||
- message (str): Status message describing the result of the operation. | ||
Contains error details if error_status is True, | ||
or success confirmation if error_status is False. | ||
""" | ||
|
||
|
||
class BlockManifest(WorkflowBlockManifest): | ||
model_config = ConfigDict( | ||
json_schema_extra={ | ||
"name": "MQTT Writer", | ||
"version": "v1", | ||
"short_description": "Publishes messages to an MQTT broker.", | ||
"long_description": LONG_DESCRIPTION, | ||
"license": "Roboflow Enterprise License", | ||
"block_type": "sink", | ||
} | ||
) | ||
type: Literal["mqtt_writer_sink@v1"] | ||
host: Union[Selector(kind=[STRING_KIND]), str] = Field( | ||
description="Host of the MQTT broker.", | ||
examples=["localhost", "$inputs.mqtt_host"], | ||
) | ||
port: Union[int, Selector(kind=[INTEGER_KIND])] = Field( | ||
description="Port of the MQTT broker.", | ||
examples=[1883, "$inputs.mqtt_port"], | ||
) | ||
topic: Union[Selector(kind=[STRING_KIND]), str] = Field( | ||
description="MQTT topic to publish the message to.", | ||
examples=["sensors/temperature", "$inputs.mqtt_topic"], | ||
) | ||
message: Union[Selector(kind=[STRING_KIND]), str] = Field( | ||
description="Message to be published.", | ||
examples=["Hello, MQTT!", "$inputs.mqtt_message"], | ||
) | ||
qos: Union[int, Selector(kind=[INTEGER_KIND])] = Field( | ||
default=0, | ||
description="Quality of Service level for the message.", | ||
examples=[0, 1, 2], | ||
) | ||
retain: Union[bool, Selector(kind=[BOOLEAN_KIND])] = Field( | ||
default=False, | ||
description="Whether the message should be retained by the broker.", | ||
examples=[True, False], | ||
) | ||
timeout: Union[float, Selector(kind=[FLOAT_KIND])] = Field( | ||
default=0.5, | ||
description="Timeout for connecting to the MQTT broker and for sending MQTT messages.", | ||
examples=[0.5], | ||
) | ||
username: Union[Selector(kind=[STRING_KIND]), str] = Field( | ||
default=None, | ||
description="Username for MQTT broker authentication.", | ||
examples=["$inputs.mqtt_username"], | ||
) | ||
password: Union[Selector(kind=[STRING_KIND]), str] = Field( | ||
default=None, | ||
description="Password for MQTT broker authentication.", | ||
examples=["$inputs.mqtt_password"], | ||
) | ||
|
||
@classmethod | ||
def describe_outputs(cls) -> List[OutputDefinition]: | ||
return [ | ||
OutputDefinition(name="error_status", kind=[BOOLEAN_KIND]), | ||
OutputDefinition(name="message", kind=[STRING_KIND]), | ||
] | ||
|
||
|
||
class MQTTWriterSinkBlockV1(WorkflowBlock): | ||
def __init__(self): | ||
self.mqtt_client: Optional[mqtt.Client] = None | ||
|
||
@classmethod | ||
def get_manifest(cls) -> Type[WorkflowBlockManifest]: | ||
return BlockManifest | ||
|
||
def run( | ||
self, | ||
host: str, | ||
port: int, | ||
topic: str, | ||
message: str, | ||
username: Optional[str] = None, | ||
password: Optional[str] = None, | ||
qos: int = 0, | ||
retain: bool = False, | ||
timeout: float = 0.5, | ||
) -> BlockResult: | ||
if self.mqtt_client is None: | ||
self.mqtt_client = mqtt.Client() | ||
if username and password: | ||
self.mqtt_client.username_pw_set(username, password) | ||
self.mqtt_client.on_connect = self.mqtt_on_connect | ||
self.mqtt_client.on_connect_fail = self.mqtt_on_connect_fail | ||
self.mqtt_client.reconnect_delay_set( | ||
min_delay=timeout, max_delay=2 * timeout | ||
) | ||
try: | ||
# TODO: blocking, consider adding fire_and_forget like in OPC writer | ||
self.mqtt_client.connect(host, port) | ||
except Exception as e: | ||
logger.error("Failed to connect to MQTT broker: %s", e) | ||
return { | ||
"error_status": True, | ||
"message": f"Failed to connect to MQTT broker: {e}", | ||
} | ||
|
||
if not self.mqtt_client.is_connected(): | ||
try: | ||
# TODO: blocking | ||
self.mqtt_client.reconnect() | ||
except Exception as e: | ||
logger.error("Failed to connect to MQTT broker: %s", e) | ||
return { | ||
"error_status": True, | ||
"message": f"Failed to connect to MQTT broker: {e}", | ||
} | ||
|
||
try: | ||
res: mqtt.MQTTMessageInfo = self.mqtt_client.publish( | ||
topic, message, qos=qos, retain=retain | ||
) | ||
# TODO: this is blocking | ||
res.wait_for_publish(timeout=timeout) | ||
if res.is_published(): | ||
return { | ||
"error_status": False, | ||
"message": "Message published successfully", | ||
} | ||
else: | ||
return {"error_status": True, "message": "Failed to publish payload"} | ||
except Exception as e: | ||
logger.error("Failed to publish message: %s", e) | ||
return {"error_status": True, "message": f"Unhandled error - {e}"} | ||
|
||
def mqtt_on_connect(self, client, userdata, flags, reason_code, properties=None): | ||
logger.info("Connected with result code %s", reason_code) | ||
|
||
def mqtt_on_connect_fail( | ||
self, client, userdata, flags, reason_code, properties=None | ||
): | ||
logger.error(f"Failed to connect with result code %s", reason_code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
tests/workflows/integration_tests/execution/test_workflow_with_mqtt_writer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import unittest | ||
import paho.mqtt.client as mqtt | ||
from inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1 import MQTTWriterSinkBlockV1 | ||
import pytest | ||
import threading | ||
|
||
@pytest.mark.timeout(5) | ||
def test_successful_connection_and_publishing(fake_mqtt_broker): | ||
# given | ||
block = MQTTWriterSinkBlockV1() | ||
published_message = 'Test message' | ||
expected_message = 'Message published successfully' | ||
|
||
fake_mqtt_broker.messages_count_to_wait_for = 2 | ||
broker_thread = threading.Thread(target=fake_mqtt_broker.start) | ||
broker_thread.start() | ||
|
||
# when | ||
result = block.run( | ||
host=fake_mqtt_broker.host, | ||
port=fake_mqtt_broker.port, | ||
topic="RoboflowTopic", | ||
message=published_message | ||
) | ||
|
||
broker_thread.join(timeout=2) | ||
|
||
# then | ||
assert result['error_status'] is False, "No error expected" | ||
assert result['message'] == expected_message | ||
|
||
assert published_message.encode() in fake_mqtt_broker.messages[-1] |
76 changes: 76 additions & 0 deletions
76
tests/workflows/unit_tests/core_steps/sinks/test_mqtt_writer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import unittest | ||
from unittest.mock import patch, MagicMock | ||
from inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1 import MQTTWriterSinkBlockV1 | ||
|
||
class TestMQTTWriterSinkBlockV1(unittest.TestCase): | ||
@patch('inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1.mqtt.Client') | ||
def test_successful_connection_and_publishing(self, MockMQTTClient): | ||
# Arrange | ||
mock_client = MockMQTTClient.return_value | ||
mock_client.is_connected.return_value = False | ||
mock_client.publish.return_value.is_published.return_value = True | ||
|
||
block = MQTTWriterSinkBlockV1() | ||
|
||
# Act | ||
result = block.run( | ||
host='localhost', | ||
port=1883, | ||
topic='test/topic', | ||
message='Hello, MQTT!', | ||
username='lenny', | ||
password='roboflow' | ||
) | ||
|
||
# Assert | ||
self.assertFalse(result['error_status']) | ||
self.assertEqual(result['message'], 'Message published successfully') | ||
|
||
@patch('inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1.mqtt.Client') | ||
def test_connection_failure(self, MockMQTTClient): | ||
# Arrange | ||
mock_client = MockMQTTClient.return_value | ||
mock_client.is_connected.return_value = False | ||
mock_client.connect.side_effect = Exception('Connection failed') | ||
|
||
block = MQTTWriterSinkBlockV1() | ||
|
||
# Act | ||
result = block.run( | ||
host='localhost', | ||
port=1883, | ||
topic='test/topic', | ||
message='Hello, MQTT!', | ||
username='lenny', | ||
password='roboflow' | ||
) | ||
|
||
# Assert | ||
self.assertTrue(result['error_status']) | ||
self.assertIn('Failed to connect to MQTT broker', result['message']) | ||
|
||
@patch('inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1.mqtt.Client') | ||
def test_publishing_failure(self, MockMQTTClient): | ||
# Arrange | ||
mock_client = MockMQTTClient.return_value | ||
mock_client.is_connected.return_value = True | ||
mock_client.publish.return_value.is_published.return_value = False | ||
|
||
block = MQTTWriterSinkBlockV1() | ||
|
||
# Act | ||
result = block.run( | ||
host='localhost', | ||
port=1883, | ||
topic='test/topic', | ||
message='Hello, MQTT!', | ||
username='lenny', | ||
password='roboflow' | ||
) | ||
|
||
# Assert | ||
self.assertTrue(result['error_status']) | ||
self.assertEqual(result['message'], 'Failed to publish payload') | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |