Skip to content

Commit

Permalink
refacto: generic de/serialize messages (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
ybensacq authored Oct 4, 2024
1 parent d82bc63 commit bf53548
Show file tree
Hide file tree
Showing 20 changed files with 84 additions and 222 deletions.
34 changes: 6 additions & 28 deletions src/network/protocol/messages/block.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
const std = @import("std");
const native_endian = @import("builtin").target.cpu.arch.endian();
const protocol = @import("../lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const genericSerialize = @import("lib.zig").genericSerialize;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;

const ServiceFlags = protocol.ServiceFlags;

Expand Down Expand Up @@ -28,15 +31,7 @@ pub const BlockMessage = struct {
}

pub fn checksum(self: BlockMessage) [4]u8 {
var digest: [32]u8 = undefined;
var hasher = Sha256.init(.{});
const writer = hasher.writer();
self.serializeToWriter(writer) catch unreachable; // Sha256.write is infaible
hasher.final(&digest);

Sha256.hash(&digest, &digest, .{});

return digest[0..4].*;
return genericChecksum(self);
}

pub fn deinit(self: *BlockMessage, allocator: std.mem.Allocator) void {
Expand Down Expand Up @@ -64,24 +59,9 @@ pub const BlockMessage = struct {
}
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const Self, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
try self.serializeToWriter(fbs.writer());
}

/// Serialize a message as bytes and return them.
pub fn serialize(self: *const BlockMessage, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
return genericSerialize(self, allocator);
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !BlockMessage {
Expand Down Expand Up @@ -112,8 +92,7 @@ pub const BlockMessage = struct {

/// Deserialize bytes into a `VersionMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
var fbs = std.io.fixedBufferStream(bytes);
return try Self.deserializeReader(allocator, fbs.reader());
return genericDeserializeSlice(Self, allocator, bytes);
}

pub fn hintSerializedLen(self: BlockMessage) usize {
Expand All @@ -128,7 +107,6 @@ pub const BlockMessage = struct {
};

// TESTS

test "ok_full_flow_BlockMessage" {
const OpCode = @import("../../../script/opcodes/constant.zig").Opcode;
const allocator = std.testing.allocator;
Expand Down
4 changes: 2 additions & 2 deletions src/network/protocol/messages/cmpctblock.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const Sha256 = std.crypto.hash.sha2.Sha256;
const BlockHeader = @import("../../../types/block_header.zig");
const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
const genericChecksum = @import("lib.zig").genericChecksum;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;

pub const CmpctBlockMessage = struct {
header: BlockHeader,
Expand Down Expand Up @@ -118,8 +119,7 @@ pub const CmpctBlockMessage = struct {
}

pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
var fbs = std.io.fixedBufferStream(bytes);
return try Self.deserializeReader(allocator, fbs.reader());
return genericDeserializeSlice(Self, allocator, bytes);
}

pub fn hintSerializedLen(self: *const Self) usize {
Expand Down
23 changes: 4 additions & 19 deletions src/network/protocol/messages/feefilter.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ const std = @import("std");
const protocol = @import("../lib.zig");
const Sha256 = std.crypto.hash.sha2.Sha256;
const genericChecksum = @import("lib.zig").genericChecksum;
const genericSerialize = @import("lib.zig").genericSerialize;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;

/// FeeFilterMessage represents the "feefilter" message
///
Expand All @@ -22,13 +24,6 @@ pub const FeeFilterMessage = struct {
return genericChecksum(self);
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const Self, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
try self.serializeToWriter(fbs.writer());
}

/// Serialize the message as bytes and write them to the Writer.
pub fn serializeToWriter(self: *const Self, w: anytype) !void {
Expand All @@ -37,14 +32,7 @@ pub const FeeFilterMessage = struct {

/// Serialize a message as bytes and return them.
pub fn serialize(self: *const Self, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
return genericSerialize(self, allocator);
}

/// Deserialize a Reader bytes as a `FeeFilterMessage`
Expand All @@ -60,10 +48,7 @@ pub const FeeFilterMessage = struct {

/// Deserialize bytes into a `FeeFilterMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
var fbs = std.io.fixedBufferStream(bytes);
const reader = fbs.reader();

return try Self.deserializeReader(allocator, reader);
return genericDeserializeSlice(Self, allocator, bytes);
}

pub fn hintSerializedLen(_: *const Self) usize {
Expand Down
17 changes: 4 additions & 13 deletions src/network/protocol/messages/filteradd.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
const std = @import("std");
const protocol = @import("../lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const genericSerialize = @import("lib.zig").genericSerialize;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;

/// FilterAddMessage represents the "filteradd" message
///
Expand Down Expand Up @@ -29,17 +31,7 @@ pub const FilterAddMessage = struct {

/// Serialize a message as bytes and return them.
pub fn serialize(self: *const Self, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();
const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);
try self.serializeToSlice(ret);
return ret;
}

/// Serialize a message as bytes and write them to the buffer.
pub fn serializeToSlice(self: *const Self, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
try self.serializeToWriter(fbs.writer());
return genericSerialize(self, allocator);
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !Self {
Expand All @@ -51,8 +43,7 @@ pub const FilterAddMessage = struct {
}

pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
var fbs = std.io.fixedBufferStream(bytes);
return try Self.deserializeReader(allocator, fbs.reader());
return genericDeserializeSlice(Self, allocator, bytes);
}

pub fn hintSerializedLen(self: *const Self) usize {
Expand Down
1 change: 0 additions & 1 deletion src/network/protocol/messages/filterclear.zig
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ pub const FilterClearMessage = struct {
};

// TESTS

test "ok_full_flow_FilterClearMessage" {
const allocator = std.testing.allocator;

Expand Down
20 changes: 4 additions & 16 deletions src/network/protocol/messages/filterload.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
const std = @import("std");
const protocol = @import("../lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const genericSerialize = @import("lib.zig").genericSerialize;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;

const Sha256 = std.crypto.hash.sha2.Sha256;
const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
Expand Down Expand Up @@ -43,20 +45,7 @@ pub const FilterLoadMessage = struct {

/// Serialize a message as bytes and return them.
pub fn serialize(self: *const Self, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
}

/// Serialize a message as bytes and write them to the buffer.
pub fn serializeToSlice(self: *const Self, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
try self.serializeToWriter(fbs.writer());
return genericSerialize(self, allocator);
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !Self {
Expand All @@ -83,8 +72,7 @@ pub const FilterLoadMessage = struct {
}

pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
var fbs = std.io.fixedBufferStream(bytes);
return try Self.deserializeReader(allocator, fbs.reader());
return genericDeserializeSlice(Self, allocator, bytes);
}

pub fn hintSerializedLen(self: *const Self) usize {
Expand Down
1 change: 0 additions & 1 deletion src/network/protocol/messages/getaddr.zig
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ pub const GetaddrMessage = struct {
};

// TESTS

test "ok_full_flow_GetaddrMessage" {
const allocator = std.testing.allocator;

Expand Down
30 changes: 7 additions & 23 deletions src/network/protocol/messages/getblocks.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
const std = @import("std");
const protocol = @import("../lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const genericSerialize = @import("lib.zig").genericSerialize;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;

const Sha256 = std.crypto.hash.sha2.Sha256;

Expand All @@ -14,6 +16,8 @@ pub const GetblocksMessage = struct {
header_hashes: [][32]u8,
stop_hash: [32]u8,

const Self = @This();

pub fn name() *const [12]u8 {
return protocol.CommandNames.GETBLOCKS ++ [_]u8{0} ** 5;
}
Expand Down Expand Up @@ -50,23 +54,7 @@ pub const GetblocksMessage = struct {

/// Serialize a message as bytes and return them.
pub fn serialize(self: *const GetblocksMessage, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const GetblocksMessage, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
const writer = fbs.writer();
try self.serializeToWriter(writer);
return genericSerialize(self, allocator);
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !GetblocksMessage {
Expand Down Expand Up @@ -96,11 +84,8 @@ pub const GetblocksMessage = struct {
}

/// Deserialize bytes into a `GetblocksMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !GetblocksMessage {
var fbs = std.io.fixedBufferStream(bytes);
const reader = fbs.reader();

return try GetblocksMessage.deserializeReader(allocator, reader);
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
return genericDeserializeSlice(Self, allocator, bytes);
}

pub fn hintSerializedLen(self: *const GetblocksMessage) usize {
Expand Down Expand Up @@ -134,7 +119,6 @@ pub const GetblocksMessage = struct {
};

// TESTS

test "ok_full_flow_GetBlocksMessage" {
const allocator = std.testing.allocator;

Expand Down
4 changes: 2 additions & 2 deletions src/network/protocol/messages/headers.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
const std = @import("std");
const protocol = @import("../lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;

const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;

Expand Down Expand Up @@ -84,8 +85,7 @@ pub const HeadersMessage = struct {

/// Deserialize bytes into a `HeaderMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
var fbs = std.io.fixedBufferStream(bytes);
return try Self.deserializeReader(allocator, fbs.reader());
return genericDeserializeSlice(Self, allocator, bytes);
}

pub fn hintSerializedLen(self: Self) usize {
Expand Down
27 changes: 27 additions & 0 deletions src/network/protocol/messages/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,30 @@ pub fn genericChecksum(m: anytype) [4]u8 {

return digest[0..4].*;
}

pub fn genericSerialize(m: anytype, allocator: std.mem.Allocator) ![]u8 {
std.debug.print("Type m: {}\n", .{@TypeOf(m)});

comptime {
if (!std.meta.hasMethod(@TypeOf(m), "hintSerializedLen")) @compileError("Expects m to have fn 'hintSerializedLen'.");
if (!std.meta.hasMethod(@TypeOf(m), "serializeToWriter")) @compileError("Expects m to have fn 'serializeToWriter'.");
}
const serialized_len = m.hintSerializedLen();

const buffer = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(buffer);

var fbs = std.io.fixedBufferStream(buffer);
try m.serializeToWriter(fbs.writer());

return buffer;
}

pub fn genericDeserializeSlice(comptime T: type, allocator: std.mem.Allocator, bytes: []const u8) !T {
if (!std.meta.hasMethod(T, "deserializeReader")) @compileError("Expects T to have fn 'deserializeReader'.");

var fbs = std.io.fixedBufferStream(bytes);
const reader = fbs.reader();

return try T.deserializeReader(allocator, reader);
}
1 change: 0 additions & 1 deletion src/network/protocol/messages/mempool.zig
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ pub const MempoolMessage = struct {
};

// TESTS

test "ok_full_flow_MempoolMessage" {
const allocator = std.testing.allocator;

Expand Down
Loading

0 comments on commit bf53548

Please sign in to comment.