Skip to content

Commit

Permalink
Buffer messages to wait for reconnect
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas ONeil <[email protected]>
  • Loading branch information
loneil committed Dec 20, 2024
1 parent f6c4c6c commit 400aad2
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 36 deletions.
17 changes: 5 additions & 12 deletions oidc-controller/api/routers/acapy_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..db.session import get_db

from ..core.config import settings
from ..routers.socketio import sio, connections_reload
from ..routers.socketio import buffered_emit, connections_reload

logger: structlog.typing.FilteringBoundLogger = structlog.getLogger(__name__)

Expand Down Expand Up @@ -39,9 +39,6 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db

# Get the saved websocket session
pid = str(auth_session.id)
connections = connections_reload()
sid = connections.get(pid)
logger.debug(f"sid: {sid} found for pid: {pid}")

if webhook_body["state"] == "presentation-received":
logger.info("presentation-received")
Expand All @@ -51,12 +48,10 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
if webhook_body["verified"] == "true":
auth_session.proof_status = AuthSessionState.VERIFIED
auth_session.presentation_exchange = webhook_body["by_format"]
if sid:
await sio.emit("status", {"status": "verified"}, to=sid)
await buffered_emit("status", {"status": "verified"}, to_pid=pid)
else:
auth_session.proof_status = AuthSessionState.FAILED
if sid:
await sio.emit("status", {"status": "failed"}, to=sid)
await buffered_emit("status", {"status": "failed"}, to_pid=pid)

await AuthSessionCRUD(db).patch(
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
Expand All @@ -67,8 +62,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
logger.info("ABANDONED")
logger.info(webhook_body["error_msg"])
auth_session.proof_status = AuthSessionState.ABANDONED
if sid:
await sio.emit("status", {"status": "abandoned"}, to=sid)
await buffered_emit("status", {"status": "abandoned"}, to_pid=pid)

await AuthSessionCRUD(db).patch(
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
Expand All @@ -93,8 +87,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
):
logger.info("EXPIRED")
auth_session.proof_status = AuthSessionState.EXPIRED
if sid:
await sio.emit("status", {"status": "expired"}, to=sid)
await buffered_emit("status", {"status": "expired"}, to_pid=pid)

await AuthSessionCRUD(db).patch(
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
Expand Down
7 changes: 2 additions & 5 deletions oidc-controller/api/routers/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ..db.session import get_db

# Access to the websocket
from ..routers.socketio import connections_reload, sio
from ..routers.socketio import buffered_emit, connections_reload

from ..verificationConfigs.crud import VerificationConfigCRUD
from ..verificationConfigs.helpers import VariableSubstitutionError
Expand All @@ -58,8 +58,6 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)):
auth_session = await AuthSessionCRUD(db).get(pid)

pid = str(auth_session.id)
connections = connections_reload()
sid = connections.get(pid)

"""
Check if proof is expired. But only if the proof has not been started.
Expand All @@ -75,8 +73,7 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)):
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
)
# Send message through the websocket.
if sid:
await sio.emit("status", {"status": "expired"}, to=sid)
await buffered_emit("status", {"status": "expired"}, to_pid=pid)

return {"proof_status": auth_session.proof_status}

Expand Down
9 changes: 2 additions & 7 deletions oidc-controller/api/routers/presentation_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..authSessions.models import AuthSession, AuthSessionState

from ..core.config import settings
from ..routers.socketio import sio, connections_reload
from ..routers.socketio import buffered_emit, connections_reload
from ..routers.oidc import gen_deep_link
from ..db.session import get_db

Expand Down Expand Up @@ -49,16 +49,11 @@ async def send_connectionless_proof_req(
pres_exch_id
)

# Get the websocket session
connections = connections_reload()
sid = connections.get(str(auth_session.id))

# If the qrcode has been scanned, toggle the verified flag
if auth_session.proof_status is AuthSessionState.NOT_STARTED:
auth_session.proof_status = AuthSessionState.PENDING
await AuthSessionCRUD(db).patch(auth_session.id, auth_session)
if sid:
await sio.emit("status", {"status": "pending"}, to=sid)
await buffered_emit("status", {"status": "pending"}, to_pid=auth_session.id)

msg = auth_session.presentation_request_msg

Expand Down
72 changes: 60 additions & 12 deletions oidc-controller/api/routers/socketio.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,84 @@
import socketio # For using websockets
import logging
import time

logger = logging.getLogger(__name__)


connections = {}
message_buffers = {}
buffer_timeout = 60 # Timeout in seconds

sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")

sio_app = socketio.ASGIApp(socketio_server=sio, socketio_path="/ws/socket.io")


@sio.event
async def connect(sid, socket):
logger.info(f">>> connect : sid={sid}")


@sio.event
async def initialize(sid, data):
global connections
# Store websocket session matched to the presentation exchange id
connections[data.get("pid")] = sid

global connections, message_buffers
pid = data.get("pid")
connections[pid] = sid
# Initialize buffer if it doesn't exist
if pid not in message_buffers:
message_buffers[pid] = []

@sio.event
async def disconnect(sid):
global connections
global connections, message_buffers
logger.info(f">>> disconnect : sid={sid}")
# Remove websocket session from the store
if len(connections) > 0:
connections = {k: v for k, v in connections.items() if v != sid}
# Find the pid associated with the sid
pid = next((k for k, v in connections.items() if v == sid), None)
if pid:
# Remove pid from connections
del connections[pid]

async def buffered_emit(event, data, to_pid=None):
global connections, message_buffers

connections = connections_reload()
sid = connections.get(to_pid)
logger.debug(f"sid: {sid} found for pid: {to_pid}")

if sid:
try:
await sio.emit(event, data, room=sid)
except:
# If send fails, buffer the message
buffer_message(to_pid, event, data)
else:
# Buffer the message if the target is not connected
buffer_message(to_pid, event, data)

def buffer_message(pid, event, data):
global message_buffers
current_time = time.time()
if pid not in message_buffers:
message_buffers[pid] = []
# Add message with timestamp and event name
message_buffers[pid].append((event, data, current_time))
# Clean up old messages
message_buffers[pid] = [
(msg_event, msg_data, timestamp) for msg_event, msg_data, timestamp in message_buffers[pid]
if current_time - timestamp <= buffer_timeout
]

@sio.event
async def fetch_buffered_messages(sid, pid):
global message_buffers
current_time = time.time()
if pid in message_buffers:
# Filter messages that are still valid (i.e., within the buffer_timeout)
valid_messages = [
(msg_event, msg_data, timestamp) for msg_event, msg_data, timestamp in message_buffers[pid]
if current_time - timestamp <= buffer_timeout
]
# Emit each valid message
for event, data, _ in valid_messages:
await sio.emit(event, data, room=sid)
# Reassign the valid_messages back to message_buffers[pid] to clean up old messages
message_buffers[pid] = valid_messages

def connections_reload():
global connections
Expand Down
10 changes: 10 additions & 0 deletions oidc-controller/api/templates/verified_credentials.html
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ <h1 class="mb-3 fw-bolder fs-1">Continue with:</h1>
>
DEBUG Disconnect Web Socket
</button>

<button
class="btn btn-primary mt-4"
v-on:click="socket.connect()"
title="Reconnect Websocket"
>
DEBUG Reconnect Web Socket
</button>
</div>

<hr v-if="mobileDevice" />
Expand Down Expand Up @@ -383,6 +391,8 @@ <h5 v-if="state.showScanned" class="fw-bolder mb-3">
`Socket connecting. SID: ${this.socket.id}. PID: {{pid}}. Recovered? ${this.socket.recovered} `
);
this.socket.emit("initialize", { pid: "{{pid}}" });
// Emit the `fetch_buffered_messages` event with `pid` as a string using Jinja templating
this.socket.emit('fetch_buffered_messages', '{{ pid }}');
});

this.socket.on("connect_error", (error) => {
Expand Down

0 comments on commit 400aad2

Please sign in to comment.