From 07b82994ffdba887f48a0ec653015ecadd7bc4ea Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Fri, 12 Dec 2025 18:48:55 +0100 Subject: [PATCH] feat: idiomatic zig --- src/protocol.zig | 96 +++++++++++++++++++++++++++++------------------- 1 file changed, 59 insertions(+), 37 deletions(-) diff --git a/src/protocol.zig b/src/protocol.zig index 907711e..97577b1 100644 --- a/src/protocol.zig +++ b/src/protocol.zig @@ -15,48 +15,70 @@ const VERSION: u8 = 1; const Data = struct { message: []const u8, + allocator: ?std.mem.Allocator, + + pub fn init(allocator: ?std.mem.Allocator, message: []const u8) Data { + return .{ + .message = message, + .allocator = allocator, + }; + } + + pub fn serialize(self: Data, allocator: std.mem.Allocator) ![]u8 { + const messageLen = self.message.len; + if (messageLen >= std.math.maxInt(u8)) { + return error.MessageToLong; + } + const data = try allocator.alloc(u8, 2 + self.message.len); + data[0] = VERSION; + data[1] = @intCast(self.message.len); + @memcpy(data[2..], self.message); + return data; + } + + pub fn deserialize(allocator: std.mem.Allocator, input: []const u8) !Data { + if (input.len < 2) { + return error.ProtocolFailure; + } + + if (input[0] != VERSION) { + return error.VersionMismatch; + } + + const messageLen = input[1]; + const message = try allocator.alloc(u8, messageLen); + @memcpy(message, input[2..(messageLen + 2)]); + + return Data.init(allocator, message); + } + + pub fn deinit(self: Data) void { + if (self.allocator) |alloc| { + alloc.free(self.message); + } + } }; -const SerializationError = error{ VersionDoesNotMatchError, OutOfMemory, OtherError }; - -fn deserialize(allocator: std.mem.Allocator, input: []u8) SerializationError!Data { - if (input.len == 0) { - return SerializationError.OtherError; - } - - if (input[0] != VERSION) { - return SerializationError.VersionDoesNotMatchError; - } - - const messageLen = input[1]; - - const message = try allocator.alloc(u8, messageLen); - - @memcpy(message, input[2..(messageLen + 2)]); - - return .{ .message = message }; -} -fn serialize(allocator: std.mem.Allocator, input: Data) ![]u8 { - const messageLen = input.message.len; - if (messageLen >= std.math.maxInt(u8)) { - return SerializationError.OtherError; - } - const data = try allocator.alloc(u8, 2 + input.message.len); - data[0] = VERSION; - data[1] = @intCast(input.message.len); - const a = input.message; - @memcpy(data[2..], a); - return data; -} - -test "expect DataV1 can be serialized and deserialized" { +test "should round trip" { const message = "Test"; - const expected: Data = .{ .message = message }; - const serialized = try serialize(std.testing.allocator, expected); + const expected = Data.init(null, message); + + const serialized = try expected.serialize(std.testing.allocator); defer std.testing.allocator.free(serialized); - const actual = try deserialize(std.testing.allocator, serialized); - defer std.testing.allocator.free(actual.message); + const actual = try Data.deserialize(std.testing.allocator, serialized); + defer actual.deinit(); try std.testing.expectEqualDeep(expected.message, actual.message); } + +test "should return deserialize error " { + try std.testing.expectError(error.ProtocolFailure, Data.deserialize(std.testing.allocator, &.{})); + try std.testing.expectError(error.ProtocolFailure, Data.deserialize(std.testing.allocator, &.{VERSION})); + try std.testing.expectError(error.VersionMismatch, Data.deserialize(std.testing.allocator, &.{ VERSION + 2, 0 })); +} + +test "should deserialize empty message" { + const data = try Data.deserialize(std.testing.allocator, &.{ VERSION, 0 }); + try std.testing.expectEqualStrings("", data.message); +}