Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split polling can_read and reading from USB #4419

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/embed/io/usb/inc/io/usb.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <io/usb_vcp.h>
#include <io/usb_webusb.h>

#define USB_PACKET_LEN 64

// clang-format off
//
// USB stack high-level state machine
Expand Down
66 changes: 50 additions & 16 deletions core/embed/io/usb/unix/usb.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ static struct {
int sock;
struct sockaddr_in si_me, si_other;
socklen_t slen;
uint8_t msg[64];
int msg_len;
} usb_ifaces[USBD_MAX_NUM_INTERFACES];

secbool usb_init(const usb_dev_info_t *dev_info) {
Expand All @@ -60,7 +62,9 @@ secbool usb_init(const usb_dev_info_t *dev_info) {
usb_ifaces[i].sock = -1;
memzero(&usb_ifaces[i].si_me, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].si_other, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].msg, sizeof(usb_ifaces[i].msg));
usb_ifaces[i].slen = 0;
usb_ifaces[i].msg_len = 0;
}
return sectrue;
}
Expand Down Expand Up @@ -136,36 +140,66 @@ secbool usb_vcp_add(const usb_vcp_info_t *info) {
return sectrue;
}

static secbool usb_emulated_poll(uint8_t iface_num, short dir) {
static secbool usb_emulated_poll_read(uint8_t iface_num) {
struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, dir, 0},
{usb_ifaces[iface_num].sock, POLLIN, 0},
};
int r = poll(fds, 1, 0);
return sectrue * (r > 0);
}
int res = poll(fds, 1, 0);

if (res <= 0) {
return secfalse;
}

static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
struct sockaddr_in si;
socklen_t sl = sizeof(si);
ssize_t r = recvfrom(usb_ifaces[iface_num].sock, buf, len, MSG_DONTWAIT,
ssize_t r = recvfrom(usb_ifaces[iface_num].sock, usb_ifaces[iface_num].msg,
sizeof(usb_ifaces[iface_num].msg), MSG_DONTWAIT,
(struct sockaddr *)&si, &sl);
if (r < 0) {
return r;
if (r <= 0) {
return secfalse;
}

usb_ifaces[iface_num].si_other = si;
usb_ifaces[iface_num].slen = sl;
static const char *ping_req = "PINGPING";
static const char *ping_resp = "PONGPONG";
if (r == strlen(ping_req) && 0 == memcmp(ping_req, buf, strlen(ping_req))) {
if (r == strlen(ping_req) &&
0 == memcmp(ping_req, usb_ifaces[iface_num].msg, strlen(ping_req))) {
if (usb_ifaces[iface_num].slen > 0) {
sendto(usb_ifaces[iface_num].sock, ping_resp, strlen(ping_resp),
MSG_DONTWAIT,
(const struct sockaddr *)&usb_ifaces[iface_num].si_other,
usb_ifaces[iface_num].slen);
}
return 0;
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
return secfalse;
}
return r;

usb_ifaces[iface_num].msg_len = r;

return sectrue;
}

static secbool usb_emulated_poll_write(uint8_t iface_num) {
struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, POLLOUT, 0},
};
int r = poll(fds, 1, 0);
return sectrue * (r > 0);
}

static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
if (usb_ifaces[iface_num].msg_len > 0) {
if (usb_ifaces[iface_num].msg_len < len) {
len = usb_ifaces[iface_num].msg_len;
}
memcpy(buf, usb_ifaces[iface_num].msg, len);
usb_ifaces[iface_num].msg_len = 0;
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
return len;
}

return 0;
}

static int usb_emulated_write(uint8_t iface_num, const uint8_t *buf,
Expand All @@ -184,31 +218,31 @@ secbool usb_hid_can_read(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLIN);
return usb_emulated_poll_read(iface_num);
}

secbool usb_webusb_can_read(uint8_t iface_num) {
if (iface_num >= USBD_MAX_NUM_INTERFACES ||
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLIN);
return usb_emulated_poll_read(iface_num);
}

secbool usb_hid_can_write(uint8_t iface_num) {
if (iface_num >= USBD_MAX_NUM_INTERFACES ||
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLOUT);
return usb_emulated_poll_write(iface_num);
}

secbool usb_webusb_can_write(uint8_t iface_num) {
if (iface_num >= USBD_MAX_NUM_INTERFACES ||
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLOUT);
return usb_emulated_poll_write(iface_num);
}

int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
Expand Down
45 changes: 45 additions & 0 deletions core/embed/upymod/modtrezorio/modtrezorio-hid.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,46 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj,
mod_trezorio_HID_write);

/// def read(self, buf: bytearray, offset: int = 0) -> int:
/// """
/// Reads message using USB HID (device) or UDP (emulator).
/// """
STATIC mp_obj_t mod_trezorio_HID_read(size_t n_args, const mp_obj_t *args) {
mp_obj_HID_t *o = MP_OBJ_TO_PTR(args[0]);
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);

int offset = 0;
if (n_args >= 2) {
offset = mp_obj_get_int(args[2]);
}

if (offset < 0) {
mp_raise_ValueError("Negative offset not allowed");
}

if (offset > buf.len) {
mp_raise_ValueError("Offset out of bounds");
}

uint32_t buffer_space = buf.len - offset;
matejcik marked this conversation as resolved.
Show resolved Hide resolved

if (buffer_space < USB_PACKET_LEN) {
mp_raise_ValueError("Buffer too small");
}

ssize_t r = usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset],
USB_PACKET_LEN);

if (r != USB_PACKET_LEN) {
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
}

return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_HID_read_obj, 2, 3,
mod_trezorio_HID_read);

/// def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
/// """
/// Sends message using USB HID (device) or UDP (emulator).
Expand All @@ -158,12 +198,17 @@ STATIC mp_obj_t mod_trezorio_HID_write_blocking(mp_obj_t self, mp_obj_t msg,
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorio_HID_write_blocking_obj,
mod_trezorio_HID_write_blocking);

/// PACKET_LEN: ClassVar[int]
/// """Length of one USB packet."""

STATIC const mp_rom_map_elem_t mod_trezorio_HID_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_HID_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_HID_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_HID_read_obj)},
{MP_ROM_QSTR(MP_QSTR_write_blocking),
MP_ROM_PTR(&mod_trezorio_HID_write_blocking_obj)},
{MP_ROM_QSTR(MP_QSTR_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_HID_locals_dict,
mod_trezorio_HID_locals_dict_table);
Expand Down
28 changes: 7 additions & 21 deletions core/embed/upymod/modtrezorio/modtrezorio-poll.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,29 +166,15 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref,
}
#endif
else if (mode == POLL_READ) {
if (sectrue == usb_hid_can_read(iface)) {
uint8_t buf[64] = {0};
int len = usb_hid_read(iface, buf, sizeof(buf));
if (len > 0) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_obj_new_bytes(buf, len);
return mp_const_true;
}
} else if (sectrue == usb_webusb_can_read(iface)) {
uint8_t buf[64] = {0};
int len = usb_webusb_read(iface, buf, sizeof(buf));
if (len > 0) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_obj_new_bytes(buf, len);
return mp_const_true;
}
}
} else if (mode == POLL_WRITE) {
if (sectrue == usb_hid_can_write(iface)) {
if ((sectrue == usb_hid_can_read(iface)) ||
(sectrue == usb_webusb_can_read(iface))) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_const_none;
ret->items[1] = MP_OBJ_NEW_SMALL_INT(USB_PACKET_LEN);
return mp_const_true;
} else if (sectrue == usb_webusb_can_write(iface)) {
}
} else if (mode == POLL_WRITE) {
if ((sectrue == usb_hid_can_write(iface)) ||
(sectrue == usb_webusb_can_write(iface))) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_const_none;
return mp_const_true;
Expand Down
45 changes: 45 additions & 0 deletions core/embed/upymod/modtrezorio/modtrezorio-webusb.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,55 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj,
mod_trezorio_WebUSB_write);

/// def read(self, buf: bytearray, offset: int = 0) -> int:
/// """
/// Reads message using USB WebUSB (device) or UDP (emulator).
/// """
STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) {
mp_obj_WebUSB_t *o = MP_OBJ_TO_PTR(args[0]);
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);

int offset = 0;
if (n_args >= 2) {
offset = mp_obj_get_int(args[2]);
}

if (offset < 0) {
mp_raise_ValueError("Negative offset not allowed");
}

if (offset > buf.len) {
mp_raise_ValueError("Offset out of bounds");
}

uint32_t buffer_space = buf.len - offset;
matejcik marked this conversation as resolved.
Show resolved Hide resolved

if (buffer_space < USB_PACKET_LEN) {
mp_raise_ValueError("Buffer too small");
}

ssize_t r = usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset],
USB_PACKET_LEN);

if (r != USB_PACKET_LEN) {
mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length");
}

return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 2, 3,
mod_trezorio_WebUSB_read);

/// PACKET_LEN: ClassVar[int]
/// """Length of one USB packet."""

STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_WebUSB_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_WebUSB_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_WebUSB_read_obj)},
{MP_ROM_QSTR(MP_QSTR_PACKET_LEN), MP_ROM_INT(USB_PACKET_LEN)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_WebUSB_locals_dict,
mod_trezorio_WebUSB_locals_dict_table);
Expand Down
14 changes: 14 additions & 0 deletions core/mocks/generated/trezorio/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ class HID:
Sends message using USB HID (device) or UDP (emulator).
"""

def read(self, buf: bytearray, offset: int = 0) -> int:
"""
Reads message using USB HID (device) or UDP (emulator).
"""

def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
"""
Sends message using USB HID (device) or UDP (emulator).
"""
PACKET_LEN: ClassVar[int]
"""Length of one USB packet."""


# upymod/modtrezorio/modtrezorio-poll.h
Expand Down Expand Up @@ -148,6 +155,13 @@ class WebUSB:
"""
Sends message using USB WebUSB (device) or UDP (emulator).
"""

def read(self, buf: bytearray, offset: int = 0) -> int:
"""
Reads message using USB WebUSB (device) or UDP (emulator).
"""
PACKET_LEN: ClassVar[int]
"""Length of one USB packet."""
from . import fatfs, haptic, sdcard
POLL_READ: int # wait until interface is readable and return read data
POLL_WRITE: int # wait until interface is writable
Expand Down
8 changes: 6 additions & 2 deletions core/src/apps/webauthn/fido2.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,9 @@ async def _read_cmd(iface: HID) -> Cmd | None:
read = loop.wait(iface.iface_num() | io.POLL_READ)

# wait for incoming command indefinitely
buf = await read
msg_len = await read
buf = bytearray(msg_len)
iface.read(buf, 0)
while True:
ifrm = overlay_struct(bytearray(buf), desc_init)
bcnt = ifrm.bcnt
Expand Down Expand Up @@ -415,7 +417,9 @@ async def _read_cmd(iface: HID) -> Cmd | None:
read.timeout_ms = _CTAP_HID_TIMEOUT_MS
while datalen < bcnt:
try:
buf = await read
msg_len = await read
buf = bytearray(msg_len)
iface.read(buf, 0)
except loop.Timeout:
if __debug__:
warning(__name__, "_ERR_MSG_TIMEOUT")
Expand Down
11 changes: 8 additions & 3 deletions core/src/trezor/wire/codec/codec_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if TYPE_CHECKING:
from trezorio import WireInterface

_REP_LEN = const(64)
_REP_LEN = io.WebUSB.PACKET_LEN

_REP_MARKER = const(63) # ord('?')
_REP_MAGIC = const(35) # org('#')
Expand All @@ -23,9 +23,12 @@ class CodecError(WireError):

async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)
report = bytearray(_REP_LEN)

# wait for initial report
report = await read
msg_len = await read
assert msg_len == len(report)
iface.read(report, 0)
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
Expand All @@ -50,7 +53,9 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag

while nread < msize:
# wait for continuation report
report = await read
msg_len = await read
assert msg_len == len(report)
iface.read(report, 0)
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")

Expand Down
Loading
Loading