Skip to content

Commit

Permalink
refactor(core): split polling can_read and reading from USB
Browse files Browse the repository at this point in the history
[no changelog]
  • Loading branch information
TychoVrahe committed Jan 13, 2025
1 parent 31fb952 commit e4f4985
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 50 deletions.
2 changes: 2 additions & 0 deletions core/embed/io/usb/inc/io/usb.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <io/usb_vcp.h>
#include <io/usb_webusb.h>

#define USB_PACKET_LEN 64

// clang-format off
//
// USB stack high-level state machine
Expand Down
66 changes: 50 additions & 16 deletions core/embed/io/usb/unix/usb.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ static struct {
int sock;
struct sockaddr_in si_me, si_other;
socklen_t slen;
uint8_t msg[64];
int msg_len;
} usb_ifaces[USBD_MAX_NUM_INTERFACES];

secbool usb_init(const usb_dev_info_t *dev_info) {
Expand All @@ -60,7 +62,9 @@ secbool usb_init(const usb_dev_info_t *dev_info) {
usb_ifaces[i].sock = -1;
memzero(&usb_ifaces[i].si_me, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].si_other, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].msg, sizeof(usb_ifaces[i].msg));
usb_ifaces[i].slen = 0;
usb_ifaces[i].msg_len = 0;
}
return sectrue;
}
Expand Down Expand Up @@ -136,36 +140,66 @@ secbool usb_vcp_add(const usb_vcp_info_t *info) {
return sectrue;
}

static secbool usb_emulated_poll(uint8_t iface_num, short dir) {
static secbool usb_emulated_poll_read(uint8_t iface_num) {
struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, dir, 0},
{usb_ifaces[iface_num].sock, POLLIN, 0},
};
int r = poll(fds, 1, 0);
return sectrue * (r > 0);
}
int res = poll(fds, 1, 0);

if (res <= 0) {
return secfalse;
}

static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
struct sockaddr_in si;
socklen_t sl = sizeof(si);
ssize_t r = recvfrom(usb_ifaces[iface_num].sock, buf, len, MSG_DONTWAIT,
ssize_t r = recvfrom(usb_ifaces[iface_num].sock, usb_ifaces[iface_num].msg,
sizeof(usb_ifaces[iface_num].msg), MSG_DONTWAIT,
(struct sockaddr *)&si, &sl);
if (r < 0) {
return r;
if (r <= 0) {
return secfalse;
}

usb_ifaces[iface_num].si_other = si;
usb_ifaces[iface_num].slen = sl;
static const char *ping_req = "PINGPING";
static const char *ping_resp = "PONGPONG";
if (r == strlen(ping_req) && 0 == memcmp(ping_req, buf, strlen(ping_req))) {
if (r == strlen(ping_req) &&
0 == memcmp(ping_req, usb_ifaces[iface_num].msg, strlen(ping_req))) {
if (usb_ifaces[iface_num].slen > 0) {
sendto(usb_ifaces[iface_num].sock, ping_resp, strlen(ping_resp),
MSG_DONTWAIT,
(const struct sockaddr *)&usb_ifaces[iface_num].si_other,
usb_ifaces[iface_num].slen);
}
return 0;
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
return secfalse;
}
return r;

usb_ifaces[iface_num].msg_len = r;

return sectrue;
}

static secbool usb_emulated_poll_write(uint8_t iface_num) {
struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, POLLOUT, 0},
};
int r = poll(fds, 1, 0);
return sectrue * (r > 0);
}

static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
if (usb_ifaces[iface_num].msg_len > 0) {
if (usb_ifaces[iface_num].msg_len < len) {
len = usb_ifaces[iface_num].msg_len;
}
memcpy(buf, usb_ifaces[iface_num].msg, len);
usb_ifaces[iface_num].msg_len = 0;
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
return len;
}

return 0;
}

static int usb_emulated_write(uint8_t iface_num, const uint8_t *buf,
Expand All @@ -184,31 +218,31 @@ secbool usb_hid_can_read(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLIN);
return usb_emulated_poll_read(iface_num);
}

secbool usb_webusb_can_read(uint8_t iface_num) {
if (iface_num >= USBD_MAX_NUM_INTERFACES ||
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLIN);
return usb_emulated_poll_read(iface_num);
}

secbool usb_hid_can_write(uint8_t iface_num) {
if (iface_num >= USBD_MAX_NUM_INTERFACES ||
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLOUT);
return usb_emulated_poll_write(iface_num);
}

secbool usb_webusb_can_write(uint8_t iface_num) {
if (iface_num >= USBD_MAX_NUM_INTERFACES ||
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLOUT);
return usb_emulated_poll_write(iface_num);
}

int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
Expand Down
45 changes: 45 additions & 0 deletions core/embed/upymod/modtrezorio/modtrezorio-hid.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,46 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj,
mod_trezorio_HID_write);

/// def read(self, buf: bytearray, offset: int = 0) -> int:
/// """
/// Reads message using USB HID (device) or UDP (emulator).
/// """
STATIC mp_obj_t mod_trezorio_HID_read(size_t n_args, const mp_obj_t *args) {
mp_obj_HID_t *o = MP_OBJ_TO_PTR(args[0]);
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);

int offset = 0;
if (n_args >= 2) {
offset = mp_obj_get_int(args[2]);
}

if (offset < 0) {
mp_raise_ValueError("Negative offset not allowed");
}

if (offset > buf.len) {
mp_raise_ValueError("Offset out of bounds");
}

uint32_t buffer_space = buf.len - offset;

if (buffer_space < USB_PACKET_LEN) {
mp_raise_ValueError("Buffer too small");
}

ssize_t r = usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset],
USB_PACKET_LEN);

if (r != USB_PACKET_LEN) {
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
}

return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_HID_read_obj, 2, 3,
mod_trezorio_HID_read);

/// def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
/// """
/// Sends message using USB HID (device) or UDP (emulator).
Expand All @@ -158,12 +198,17 @@ STATIC mp_obj_t mod_trezorio_HID_write_blocking(mp_obj_t self, mp_obj_t msg,
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorio_HID_write_blocking_obj,
mod_trezorio_HID_write_blocking);

/// PACKET_LEN: ClassVar[int]
/// """Length of one USB packet."""

STATIC const mp_rom_map_elem_t mod_trezorio_HID_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_HID_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_HID_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_HID_read_obj)},
{MP_ROM_QSTR(MP_QSTR_write_blocking),
MP_ROM_PTR(&mod_trezorio_HID_write_blocking_obj)},
{MP_ROM_QSTR(MP_QSTR_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_HID_locals_dict,
mod_trezorio_HID_locals_dict_table);
Expand Down
28 changes: 7 additions & 21 deletions core/embed/upymod/modtrezorio/modtrezorio-poll.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,29 +166,15 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref,
}
#endif
else if (mode == POLL_READ) {
if (sectrue == usb_hid_can_read(iface)) {
uint8_t buf[64] = {0};
int len = usb_hid_read(iface, buf, sizeof(buf));
if (len > 0) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_obj_new_bytes(buf, len);
return mp_const_true;
}
} else if (sectrue == usb_webusb_can_read(iface)) {
uint8_t buf[64] = {0};
int len = usb_webusb_read(iface, buf, sizeof(buf));
if (len > 0) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_obj_new_bytes(buf, len);
return mp_const_true;
}
}
} else if (mode == POLL_WRITE) {
if (sectrue == usb_hid_can_write(iface)) {
if ((sectrue == usb_hid_can_read(iface)) ||
(sectrue == usb_webusb_can_read(iface))) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_const_none;
ret->items[1] = MP_OBJ_NEW_SMALL_INT(USB_PACKET_LEN);
return mp_const_true;
} else if (sectrue == usb_webusb_can_write(iface)) {
}
} else if (mode == POLL_WRITE) {
if ((sectrue == usb_hid_can_write(iface)) ||
(sectrue == usb_webusb_can_write(iface))) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_const_none;
return mp_const_true;
Expand Down
45 changes: 45 additions & 0 deletions core/embed/upymod/modtrezorio/modtrezorio-webusb.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,55 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj,
mod_trezorio_WebUSB_write);

/// def read(self, buf: bytearray, offset: int = 0) -> int:
/// """
/// Reads message using USB WebUSB (device) or UDP (emulator).
/// """
STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) {
mp_obj_WebUSB_t *o = MP_OBJ_TO_PTR(args[0]);
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);

int offset = 0;
if (n_args >= 2) {
offset = mp_obj_get_int(args[2]);
}

if (offset < 0) {
mp_raise_ValueError("Negative offset not allowed");
}

if (offset > buf.len) {
mp_raise_ValueError("Offset out of bounds");
}

uint32_t buffer_space = buf.len - offset;

if (buffer_space < USB_PACKET_LEN) {
mp_raise_ValueError("Buffer too small");
}

ssize_t r = usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset],
USB_PACKET_LEN);

if (r != USB_PACKET_LEN) {
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
}

return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 2, 3,
mod_trezorio_WebUSB_read);

/// PACKET_LEN: ClassVar[int]
/// """Length of one USB packet."""

STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_WebUSB_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_WebUSB_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_WebUSB_read_obj)},
{MP_ROM_QSTR(MP_QSTR_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_WebUSB_locals_dict,
mod_trezorio_WebUSB_locals_dict_table);
Expand Down
14 changes: 14 additions & 0 deletions core/mocks/generated/trezorio/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ class HID:
Sends message using USB HID (device) or UDP (emulator).
"""

def read(self, buf: bytearray, offset: int = 0) -> int:
"""
Reads message using USB HID (device) or UDP (emulator).
"""

def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
"""
Sends message using USB HID (device) or UDP (emulator).
"""
PACKET_LEN: ClassVar[int]
"""Length of one USB packet."""


# upymod/modtrezorio/modtrezorio-poll.h
Expand Down Expand Up @@ -148,6 +155,13 @@ class WebUSB:
"""
Sends message using USB WebUSB (device) or UDP (emulator).
"""

def read(self, buf: bytearray, offset: int = 0) -> int:
"""
Reads message using USB WebUSB (device) or UDP (emulator).
"""
PACKET_LEN: ClassVar[int]
"""Length of one USB packet."""
from . import fatfs, haptic, sdcard
POLL_READ: int # wait until interface is readable and return read data
POLL_WRITE: int # wait until interface is writable
Expand Down
8 changes: 6 additions & 2 deletions core/src/apps/webauthn/fido2.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,9 @@ async def _read_cmd(iface: HID) -> Cmd | None:
read = loop.wait(iface.iface_num() | io.POLL_READ)

# wait for incoming command indefinitely
buf = await read
msg_len = await read
buf = bytearray(msg_len)
iface.read(buf, 0)
while True:
ifrm = overlay_struct(bytearray(buf), desc_init)
bcnt = ifrm.bcnt
Expand Down Expand Up @@ -415,7 +417,9 @@ async def _read_cmd(iface: HID) -> Cmd | None:
read.timeout_ms = _CTAP_HID_TIMEOUT_MS
while datalen < bcnt:
try:
buf = await read
msg_len = await read
buf = bytearray(msg_len)
iface.read(buf, 0)
except loop.Timeout:
if __debug__:
warning(__name__, "_ERR_MSG_TIMEOUT")
Expand Down
11 changes: 8 additions & 3 deletions core/src/trezor/wire/codec/codec_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if TYPE_CHECKING:
from trezorio import WireInterface

_REP_LEN = const(64)
_REP_LEN = io.WebUSB.PACKET_LEN

_REP_MARKER = const(63) # ord('?')
_REP_MAGIC = const(35) # org('#')
Expand All @@ -23,9 +23,12 @@ class CodecError(WireError):

async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)
report = bytearray(_REP_LEN)

# wait for initial report
report = await read
msg_len = await read
assert msg_len == len(report)
iface.read(report, 0)
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
Expand All @@ -50,7 +53,9 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag

while nread < msize:
# wait for continuation report
report = await read
msg_len = await read
assert msg_len == len(report)
iface.read(report, 0)
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")

Expand Down
Loading

0 comments on commit e4f4985

Please sign in to comment.