Skip to content

Commit

Permalink
Merge pull request #930 from roboflow/mqtt-writer
Browse files Browse the repository at this point in the history
Added the MQTT Writer Block
  • Loading branch information
PawelPeczek-Roboflow authored Jan 10, 2025
2 parents d149fb7 + 9e29b0d commit 1b41056
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 0 deletions.
4 changes: 4 additions & 0 deletions inference/enterprise/workflows/enterprise_blocks/loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import List, Type

from inference.core.workflows.prototypes.block import WorkflowBlock
from inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1 import (
MQTTWriterSinkBlockV1,
)
from inference.enterprise.workflows.enterprise_blocks.sinks.opc_writer.v1 import (
OPCWriterSinkBlockV1,
)
Expand All @@ -12,5 +15,6 @@
def load_enterprise_blocks() -> List[Type[WorkflowBlock]]:
return [
OPCWriterSinkBlockV1,
MQTTWriterSinkBlockV1,
PLCBlockV1,
]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from typing import List, Literal, Optional, Type, Union

import paho.mqtt.client as mqtt
from pydantic import ConfigDict, Field

from inference.core.logger import logger
from inference.core.workflows.execution_engine.entities.base import OutputDefinition
from inference.core.workflows.execution_engine.entities.types import (
BOOLEAN_KIND,
FLOAT_KIND,
INTEGER_KIND,
STRING_KIND,
Selector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
WorkflowBlock,
WorkflowBlockManifest,
)

LONG_DESCRIPTION = """
MQTT Writer block for publishing messages to an MQTT broker.
This block is blocking on connect and publish operations.
Outputs:
- error_status (bool): Indicates if an error occurred during the MQTT publishing process.
True if there was an error, False if successful.
- message (str): Status message describing the result of the operation.
Contains error details if error_status is True,
or success confirmation if error_status is False.
"""


class BlockManifest(WorkflowBlockManifest):
model_config = ConfigDict(
json_schema_extra={
"name": "MQTT Writer",
"version": "v1",
"short_description": "Publishes messages to an MQTT broker.",
"long_description": LONG_DESCRIPTION,
"license": "Roboflow Enterprise License",
"block_type": "sink",
}
)
type: Literal["mqtt_writer_sink@v1"]
host: Union[Selector(kind=[STRING_KIND]), str] = Field(
description="Host of the MQTT broker.",
examples=["localhost", "$inputs.mqtt_host"],
)
port: Union[int, Selector(kind=[INTEGER_KIND])] = Field(
description="Port of the MQTT broker.",
examples=[1883, "$inputs.mqtt_port"],
)
topic: Union[Selector(kind=[STRING_KIND]), str] = Field(
description="MQTT topic to publish the message to.",
examples=["sensors/temperature", "$inputs.mqtt_topic"],
)
message: Union[Selector(kind=[STRING_KIND]), str] = Field(
description="Message to be published.",
examples=["Hello, MQTT!", "$inputs.mqtt_message"],
)
qos: Union[int, Selector(kind=[INTEGER_KIND])] = Field(
default=0,
description="Quality of Service level for the message.",
examples=[0, 1, 2],
)
retain: Union[bool, Selector(kind=[BOOLEAN_KIND])] = Field(
default=False,
description="Whether the message should be retained by the broker.",
examples=[True, False],
)
timeout: Union[float, Selector(kind=[FLOAT_KIND])] = Field(
default=0.5,
description="Timeout for connecting to the MQTT broker and for sending MQTT messages.",
examples=[0.5],
)
username: Union[Selector(kind=[STRING_KIND]), str] = Field(
default=None,
description="Username for MQTT broker authentication.",
examples=["$inputs.mqtt_username"],
)
password: Union[Selector(kind=[STRING_KIND]), str] = Field(
default=None,
description="Password for MQTT broker authentication.",
examples=["$inputs.mqtt_password"],
)

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return [
OutputDefinition(name="error_status", kind=[BOOLEAN_KIND]),
OutputDefinition(name="message", kind=[STRING_KIND]),
]


class MQTTWriterSinkBlockV1(WorkflowBlock):
def __init__(self):
self.mqtt_client: Optional[mqtt.Client] = None

@classmethod
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
return BlockManifest

def run(
self,
host: str,
port: int,
topic: str,
message: str,
username: Optional[str] = None,
password: Optional[str] = None,
qos: int = 0,
retain: bool = False,
timeout: float = 0.5,
) -> BlockResult:
if self.mqtt_client is None:
self.mqtt_client = mqtt.Client()
if username and password:
self.mqtt_client.username_pw_set(username, password)
self.mqtt_client.on_connect = self.mqtt_on_connect
self.mqtt_client.on_connect_fail = self.mqtt_on_connect_fail
self.mqtt_client.reconnect_delay_set(
min_delay=timeout, max_delay=2 * timeout
)
try:
# TODO: blocking, consider adding fire_and_forget like in OPC writer
self.mqtt_client.connect(host, port)
except Exception as e:
logger.error("Failed to connect to MQTT broker: %s", e)
return {
"error_status": True,
"message": f"Failed to connect to MQTT broker: {e}",
}

if not self.mqtt_client.is_connected():
try:
# TODO: blocking
self.mqtt_client.reconnect()
except Exception as e:
logger.error("Failed to connect to MQTT broker: %s", e)
return {
"error_status": True,
"message": f"Failed to connect to MQTT broker: {e}",
}

try:
res: mqtt.MQTTMessageInfo = self.mqtt_client.publish(
topic, message, qos=qos, retain=retain
)
# TODO: this is blocking
res.wait_for_publish(timeout=timeout)
if res.is_published():
return {
"error_status": False,
"message": "Message published successfully",
}
else:
return {"error_status": True, "message": "Failed to publish payload"}
except Exception as e:
logger.error("Failed to publish message: %s", e)
return {"error_status": True, "message": f"Unhandled error - {e}"}

def mqtt_on_connect(self, client, userdata, flags, reason_code, properties=None):
logger.info("Connected with result code %s", reason_code)

def mqtt_on_connect_fail(
self, client, userdata, flags, reason_code, properties=None
):
logger.error(f"Failed to connect with result code %s", reason_code)
1 change: 1 addition & 0 deletions requirements/_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ tldextract~=5.1.2
packaging~=24.0
anthropic~=0.34.2
pandas>=2.0.0,<2.3.0
paho-mqtt~=1.6.1
pytest>=8.0.0,<9.0.0 # this is not a joke, sam2 requires this as the fork we are using is dependent on that, yet
# do not mark the dependency: https://github.com/SauravMaheshkar/samv2/blob/main/sam2/utils/download.py
tokenizers>=0.19.0,<=0.20.3
Expand Down
61 changes: 61 additions & 0 deletions tests/workflows/integration_tests/execution/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import os.path
import socket
import tempfile
from typing import Generator

Expand Down Expand Up @@ -96,3 +98,62 @@ def bool_env(val):
@pytest.fixture(scope="function")
def face_image() -> np.ndarray:
return cv2.imread(os.path.join(ASSETS_DIR, "face.jpeg"))


# Below taken from https://github.com/eclipse-paho/paho.mqtt.python/blob/d45de3737879cfe7a6acc361631fa5cb1ef584bb/tests/testsupport/broker.py
class FakeMQTTBroker:
def __init__(self):
# Bind to "localhost" for maximum performance, as described in:
# http://docs.python.org/howto/sockets.html#ipc
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.host = "localhost"
sock.bind((self.host, 0))
self.port = sock.getsockname()[1]
self.messages = []
self.messages_count_to_wait_for = 2

sock.settimeout(5)
sock.listen(1)

self._sock = sock
self._conn = None

def start(self):
if self._sock is None:
raise ValueError("Socket is not open")
if self._conn is not None:
raise ValueError("Connection is already open")

while len(self.messages) < self.messages_count_to_wait_for:
(conn, address) = self._sock.accept()
conn.settimeout(1)
self._conn = conn
self.messages.append(self.receive_packet(1000))

def finish(self):
if self._conn is not None:
self._conn.close()
self._conn = None

if self._sock is not None:
self._sock.close()
self._sock = None

def receive_packet(self, num_bytes):
if self._conn is None:
raise ValueError("Connection is not open")

packet_in = self._conn.recv(num_bytes)
return packet_in


@pytest.fixture(scope="function")
def fake_mqtt_broker():
print("Setup broker")
broker = FakeMQTTBroker()

yield broker

print("Teardown broker")
broker.finish()
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest
import paho.mqtt.client as mqtt
from inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1 import MQTTWriterSinkBlockV1
import pytest
import threading

@pytest.mark.timeout(5)
def test_successful_connection_and_publishing(fake_mqtt_broker):
# given
block = MQTTWriterSinkBlockV1()
published_message = 'Test message'
expected_message = 'Message published successfully'

fake_mqtt_broker.messages_count_to_wait_for = 2
broker_thread = threading.Thread(target=fake_mqtt_broker.start)
broker_thread.start()

# when
result = block.run(
host=fake_mqtt_broker.host,
port=fake_mqtt_broker.port,
topic="RoboflowTopic",
message=published_message
)

broker_thread.join(timeout=2)

# then
assert result['error_status'] is False, "No error expected"
assert result['message'] == expected_message

assert published_message.encode() in fake_mqtt_broker.messages[-1]
76 changes: 76 additions & 0 deletions tests/workflows/unit_tests/core_steps/sinks/test_mqtt_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import unittest
from unittest.mock import patch, MagicMock
from inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1 import MQTTWriterSinkBlockV1

class TestMQTTWriterSinkBlockV1(unittest.TestCase):
@patch('inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1.mqtt.Client')
def test_successful_connection_and_publishing(self, MockMQTTClient):
# Arrange
mock_client = MockMQTTClient.return_value
mock_client.is_connected.return_value = False
mock_client.publish.return_value.is_published.return_value = True

block = MQTTWriterSinkBlockV1()

# Act
result = block.run(
host='localhost',
port=1883,
topic='test/topic',
message='Hello, MQTT!',
username='lenny',
password='roboflow'
)

# Assert
self.assertFalse(result['error_status'])
self.assertEqual(result['message'], 'Message published successfully')

@patch('inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1.mqtt.Client')
def test_connection_failure(self, MockMQTTClient):
# Arrange
mock_client = MockMQTTClient.return_value
mock_client.is_connected.return_value = False
mock_client.connect.side_effect = Exception('Connection failed')

block = MQTTWriterSinkBlockV1()

# Act
result = block.run(
host='localhost',
port=1883,
topic='test/topic',
message='Hello, MQTT!',
username='lenny',
password='roboflow'
)

# Assert
self.assertTrue(result['error_status'])
self.assertIn('Failed to connect to MQTT broker', result['message'])

@patch('inference.enterprise.workflows.enterprise_blocks.sinks.mqtt_writer.v1.mqtt.Client')
def test_publishing_failure(self, MockMQTTClient):
# Arrange
mock_client = MockMQTTClient.return_value
mock_client.is_connected.return_value = True
mock_client.publish.return_value.is_published.return_value = False

block = MQTTWriterSinkBlockV1()

# Act
result = block.run(
host='localhost',
port=1883,
topic='test/topic',
message='Hello, MQTT!',
username='lenny',
password='roboflow'
)

# Assert
self.assertTrue(result['error_status'])
self.assertEqual(result['message'], 'Failed to publish payload')

if __name__ == '__main__':
unittest.main()

0 comments on commit 1b41056

Please sign in to comment.