From 953d33302319f0710de7324bbc48cd9130801be1 Mon Sep 17 00:00:00 2001 From: Raj Nishtala Date: Tue, 14 Jan 2025 13:51:58 -0500 Subject: [PATCH] Switch back to using UnmarshalText for validating compression types --- config/configcompression/compressiontype.go | 21 ++++++++++++------- .../configcompression/compressiontype_test.go | 6 +++--- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/config/configcompression/compressiontype.go b/config/configcompression/compressiontype.go index a16d9f7d1d1..a7edac428ed 100644 --- a/config/configcompression/compressiontype.go +++ b/config/configcompression/compressiontype.go @@ -31,17 +31,24 @@ const ( // IsCompressed returns false if CompressionType is nil, none, or empty. // Otherwise, returns true. -func (t Type) IsCompressed() bool { - return t != typeEmpty && t != typeNone +func (ct *Type) IsCompressed() bool { + return *ct != typeEmpty && *ct != typeNone } -func (t Type) Validate() error { - switch t { - case TypeGzip, TypeZlib, TypeDeflate, TypeSnappy, TypeZstd, TypeLz4, - typeNone, typeEmpty: +func (ct *Type) UnmarshalText(in []byte) error { + typ := Type(in) + if typ == TypeGzip || + typ == TypeZlib || + typ == TypeDeflate || + typ == TypeSnappy || + typ == TypeZstd || + typ == TypeLz4 || + typ == typeNone || + typ == typeEmpty { + *ct = typ return nil } - return fmt.Errorf("unsupported compression type %q", t) + return fmt.Errorf("unsupported compression type %q", typ) } func (t Type) ValidateParams(p CompressionParams) error { diff --git a/config/configcompression/compressiontype_test.go b/config/configcompression/compressiontype_test.go index 0ea54ce1c3e..2fb9976eb4e 100644 --- a/config/configcompression/compressiontype_test.go +++ b/config/configcompression/compressiontype_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestValidate(t *testing.T) { +func TestUnmarshalText(t *testing.T) { tests := []struct { name string compressionName []byte @@ -72,8 +72,8 @@ func TestValidate(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - temp := Type(tt.compressionName) - err := temp.Validate() + temp := typeNone + err := temp.UnmarshalText(tt.compressionName) if tt.shouldError { assert.Error(t, err) return