diff --git a/encryption.go b/encryption.go deleted file mode 100644 index 1143602..0000000 --- a/encryption.go +++ /dev/null @@ -1,101 +0,0 @@ -package walrus_go - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "errors" - "io" -) - -var ( - // ErrDecryption indicates a decryption failure, likely due to an incorrect key - ErrDecryption = errors.New("failed to decrypt data: invalid key or corrupted data") - // Magic bytes for encryption validation - magicBytes = []byte("WAL_V1") -) - -// EncryptStream encrypts data from src using AES-CTR and writes the encrypted output to dst -func EncryptStream(key []byte, src io.Reader, dst io.Writer) error { - block, err := aes.NewCipher(key) - if err != nil { - return err - } - - iv := make([]byte, aes.BlockSize) - if _, err := rand.Read(iv); err != nil { - return err - } - - // Write magic bytes first - if _, err := dst.Write(magicBytes); err != nil { - return err - } - - // Write IV after magic bytes - if _, err := dst.Write(iv); err != nil { - return err - } - - stream := cipher.NewCTR(block, iv) - - // Encrypt magic bytes verification - verificationBytes := make([]byte, len(magicBytes)) - stream.XORKeyStream(verificationBytes, magicBytes) - if _, err := dst.Write(verificationBytes); err != nil { - return err - } - - // Reset stream for actual data encryption - stream = cipher.NewCTR(block, iv) - writer := &cipher.StreamWriter{S: stream, W: dst} - - // Copy from src to writer, encryption happens automatically during copy - _, err = io.Copy(writer, src) - return err -} - -// DecryptStream reads AES-CTR encrypted data from src and writes decrypted output to dst -func DecryptStream(key []byte, src io.Reader, dst io.Writer) error { - // Read and verify magic bytes - header := make([]byte, len(magicBytes)) - if _, err := io.ReadFull(src, header); err != nil { - return ErrDecryption - } - if string(header) != string(magicBytes) { - return ErrDecryption - } - - block, err := aes.NewCipher(key) - if err != nil { - return err - } - - iv := make([]byte, aes.BlockSize) - if _, err := io.ReadFull(src, iv); err != nil { - return ErrDecryption - } - - stream := cipher.NewCTR(block, iv) - - // Read and verify encrypted magic bytes - encryptedVerification := make([]byte, len(magicBytes)) - if _, err := io.ReadFull(src, encryptedVerification); err != nil { - return ErrDecryption - } - - // Decrypt verification bytes - verificationBytes := make([]byte, len(magicBytes)) - stream.XORKeyStream(verificationBytes, encryptedVerification) - if string(verificationBytes) != string(magicBytes) { - return ErrDecryption - } - - // Reset stream for actual data decryption - stream = cipher.NewCTR(block, iv) - reader := &cipher.StreamReader{S: stream, R: src} - - // Copy decrypted data from reader to dst - _, err = io.Copy(dst, reader) - return err -} diff --git a/encryption/aes_cbc.go b/encryption/aes_cbc.go new file mode 100644 index 0000000..feb4f12 --- /dev/null +++ b/encryption/aes_cbc.go @@ -0,0 +1,213 @@ +package encryption + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "io" +) + +// PKCS7Padder implements PKCS7 padding +type PKCS7Padder struct { + blockSize int +} + +// Pad adds padding to the input slice according to PKCS7 +func (p *PKCS7Padder) Pad(data []byte, size int) ([]byte, error) { + padding := p.blockSize - (size % p.blockSize) + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padtext...), nil +} + +// Unpad removes PKCS7 padding from the input slice +func (p *PKCS7Padder) Unpad(data []byte) ([]byte, error) { + length := len(data) + if length == 0 { + return nil, nil + } + + padding := int(data[length-1]) + return data[:length-padding], nil +} + +type cbcEncryptReader struct { + encrypter cipher.BlockMode + src io.Reader + padder *PKCS7Padder + size int + buf bytes.Buffer +} + +func (r *cbcEncryptReader) Read(data []byte) (int, error) { + n, err := r.src.Read(data) + r.size += n + blockSize := r.encrypter.BlockSize() + r.buf.Write(data[:n]) + + if err == io.EOF { + b := make([]byte, getSliceSize(blockSize, r.buf.Len(), len(data))) + n, err = r.buf.Read(b) + if err != nil && err != io.EOF { + return n, err + } + + if r.buf.Len() == 0 { + b, err = r.padder.Pad(b[:n], r.size) + if err != nil { + return n, err + } + n = len(b) + err = io.EOF + } + + if n > 0 { + r.encrypter.CryptBlocks(data, b) + } + return n, err + } + + if err != nil { + return n, err + } + + if size := r.buf.Len(); size >= blockSize { + nBlocks := size / blockSize + if size > len(data) { + nBlocks = len(data) / blockSize + } + + if nBlocks > 0 { + b := make([]byte, nBlocks*blockSize) + n, _ = r.buf.Read(b) + r.encrypter.CryptBlocks(data, b[:n]) + } + } else { + n = 0 + } + return n, nil +} + +type cbcDecryptReader struct { + decrypter cipher.BlockMode + src io.Reader + padder *PKCS7Padder + buf bytes.Buffer +} + +func (r *cbcDecryptReader) Read(data []byte) (int, error) { + n, err := r.src.Read(data) + blockSize := r.decrypter.BlockSize() + r.buf.Write(data[:n]) + + if err == io.EOF { + b := make([]byte, getSliceSize(blockSize, r.buf.Len(), len(data))) + n, err = r.buf.Read(b) + if err != nil && err != io.EOF { + return n, err + } + + if n > 0 { + r.decrypter.CryptBlocks(data, b) + } + + if r.buf.Len() == 0 { + b, err = r.padder.Unpad(data[:n]) + n = len(b) + if err != nil { + return n, err + } + err = io.EOF + } + return n, err + } + + if err != nil { + return n, err + } + + if size := r.buf.Len(); size >= blockSize { + nBlocks := size / blockSize + if size > len(data) { + nBlocks = len(data) / blockSize + } + nBlocks -= blockSize + + if nBlocks > 0 { + b := make([]byte, nBlocks*blockSize) + n, _ = r.buf.Read(b) + r.decrypter.CryptBlocks(data, b[:n]) + } else { + n = 0 + } + } + + return n, nil +} + +func getSliceSize(blockSize, bufSize, dataSize int) int { + size := bufSize + if bufSize > dataSize { + size = dataSize + } + size = size - (size % blockSize) - blockSize + if size <= 0 { + size = blockSize + } + return size +} + +type cbcCipher struct { + key []byte + iv []byte +} + +// EncryptStreamCBC encrypts data from src using AES-CBC and writes the encrypted output to dst +func (c cbcCipher) EncryptStream(src io.Reader, dst io.Writer) error { + block, err := aes.NewCipher(c.key) + if err != nil { + return err + } + + // Write IV first + if _, err := dst.Write(c.iv); err != nil { + return err + } + + encrypter := cipher.NewCBCEncrypter(block, c.iv) + padder := &PKCS7Padder{blockSize: block.BlockSize()} + + reader := &cbcEncryptReader{ + encrypter: encrypter, + src: src, + padder: padder, + } + + _, err = io.Copy(dst, reader) + return err +} + +// DecryptStream reads AES-CBC encrypted data from src and writes decrypted output to dst +func (c cbcCipher) DecryptStream(src io.Reader, dst io.Writer) error { + block, err := aes.NewCipher(c.key) + if err != nil { + return err + } + + // Read IV + iv := make([]byte, block.BlockSize()) + if _, err := io.ReadFull(src, iv); err != nil { + return err + } + + decrypter := cipher.NewCBCDecrypter(block, iv) + padder := &PKCS7Padder{blockSize: block.BlockSize()} + + reader := &cbcDecryptReader{ + decrypter: decrypter, + src: src, + padder: padder, + } + + _, err = io.Copy(dst, reader) + return err +} diff --git a/encryption/aes_cbc_test.go b/encryption/aes_cbc_test.go new file mode 100644 index 0000000..fc0a7b8 --- /dev/null +++ b/encryption/aes_cbc_test.go @@ -0,0 +1,194 @@ +package encryption + +import ( + "bytes" + "crypto/rand" + "io" + "testing" +) + +func TestCBCCipher(t *testing.T) { + // Test cases with different data sizes + testSizes := []int{ + 16, // One block + 32, // Two blocks + 63, // Not block aligned + 1024, // 1KB + 65536, // 64KB + } + + for _, size := range testSizes { + t.Run(formatTestName(size), func(t *testing.T) { + // Generate random test data + plaintext := make([]byte, size) + rand.Read(plaintext) + + // Create cipher + key := make([]byte, 32) + iv := make([]byte, 16) + rand.Read(key) + rand.Read(iv) + + cipher, err := NewCBCCipher(key, iv) + if err != nil { + t.Fatalf("Failed to create CBC cipher: %v", err) + } + + // Test encryption and decryption + var encrypted bytes.Buffer + var decrypted bytes.Buffer + + // Encrypt + err = cipher.EncryptStream(bytes.NewReader(plaintext), &encrypted) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + // Decrypt + err = cipher.DecryptStream(bytes.NewReader(encrypted.Bytes()), &decrypted) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + // Verify + if !bytes.Equal(plaintext, decrypted.Bytes()) { + t.Error("Decrypted data doesn't match original") + } + }) + } +} + +func TestCBCCipherErrors(t *testing.T) { + tests := []struct { + name string + key []byte + iv []byte + wantErr string + }{ + { + name: "invalid key size", + key: make([]byte, 15), + iv: make([]byte, 16), + wantErr: "invalid key size", + }, + { + name: "invalid IV size", + key: make([]byte, 32), + iv: make([]byte, 15), + wantErr: "IV length must equal block size", + }, + { + name: "nil key", + key: nil, + iv: make([]byte, 16), + wantErr: "invalid key size", + }, + { + name: "nil IV", + key: make([]byte, 32), + iv: nil, + wantErr: "IV length must equal block size", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewCBCCipher(tt.key, tt.iv) + if err == nil { + t.Error("Expected error but got none") + return + } + if err.Error() != tt.wantErr { + t.Errorf("Expected error '%s', got '%s'", tt.wantErr, err.Error()) + } + }) + } +} + +func TestCBCStreamErrors(t *testing.T) { + key := make([]byte, 32) + iv := make([]byte, 16) + rand.Read(key) + rand.Read(iv) + + cipher, err := NewCBCCipher(key, iv) + if err != nil { + t.Fatalf("Failed to create cipher: %v", err) + } + + // Test with failing reader/writer + failingReader := &failingReader{err: io.ErrUnexpectedEOF} + failingWriter := &failingWriter{err: io.ErrShortWrite} + + tests := []struct { + name string + test func() error + }{ + { + name: "encryption with failing reader", + test: func() error { + return cipher.EncryptStream(failingReader, &bytes.Buffer{}) + }, + }, + { + name: "encryption with failing writer", + test: func() error { + return cipher.EncryptStream(bytes.NewReader([]byte("test")), failingWriter) + }, + }, + { + name: "decryption with failing reader", + test: func() error { + return cipher.DecryptStream(failingReader, &bytes.Buffer{}) + }, + }, + { + name: "decryption with failing writer", + test: func() error { + return cipher.DecryptStream(bytes.NewReader(make([]byte, 32)), failingWriter) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.test(); err == nil { + t.Error("Expected error but got none") + } + }) + } +} + +// Helper types and functions +type failingReader struct { + err error +} + +func (r *failingReader) Read(p []byte) (n int, err error) { + return 0, r.err +} + +type failingWriter struct { + err error +} + +func (w *failingWriter) Write(p []byte) (n int, err error) { + return 0, w.err +} + +func formatTestName(size int) string { + switch { + case size >= 1024: + return formatKB(size) + default: + return formatBytes(size) + } +} + +func formatKB(size int) string { + return formatBytes(size/1024) + "KB" +} + +func formatBytes(size int) string { + return string(rune(size)) + "B" +} \ No newline at end of file diff --git a/encryption/aes_gcm.go b/encryption/aes_gcm.go new file mode 100644 index 0000000..93e97e1 --- /dev/null +++ b/encryption/aes_gcm.go @@ -0,0 +1,80 @@ +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "io" +) + +type gcmCipher struct { + key []byte +} + +func (c *gcmCipher) EncryptStream(src io.Reader, dst io.Writer) error { + block, err := aes.NewCipher(c.key) + if err != nil { + return err + } + + aead, err := cipher.NewGCM(block) + if err != nil { + return err + } + + // Generate nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return err + } + + // Write nonce first + if _, err := dst.Write(nonce); err != nil { + return err + } + + // Read all data from src + plaintext, err := io.ReadAll(src) + if err != nil { + return err + } + + // Encrypt and write the data + ciphertext := aead.Seal(nil, nonce, plaintext, nil) + _, err = dst.Write(ciphertext) + return err +} + +func (c *gcmCipher) DecryptStream(src io.Reader, dst io.Writer) error { + block, err := aes.NewCipher(c.key) + if err != nil { + return err + } + + aead, err := cipher.NewGCM(block) + if err != nil { + return err + } + + // Read nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := io.ReadFull(src, nonce); err != nil { + return err + } + + // Read the ciphertext + ciphertext, err := io.ReadAll(src) + if err != nil { + return err + } + + // Decrypt the data + plaintext, err := aead.Open(nil, nonce, ciphertext, nil) + if err != nil { + return err + } + + // Write the decrypted data + _, err = dst.Write(plaintext) + return err +} diff --git a/encryption/aes_gcm_test.go b/encryption/aes_gcm_test.go new file mode 100644 index 0000000..3e7ca1c --- /dev/null +++ b/encryption/aes_gcm_test.go @@ -0,0 +1,167 @@ +package encryption + +import ( + "bytes" + "crypto/rand" + "io" + "testing" +) + +func TestGCMCipher(t *testing.T) { + // Test cases with different data sizes + testSizes := []int{ + 16, // Small data + 1024, // 1KB + 65536, // 64KB + } + + for _, size := range testSizes { + t.Run(formatTestName(size), func(t *testing.T) { + // Generate random test data + plaintext := make([]byte, size) + rand.Read(plaintext) + + // Create cipher + key := make([]byte, 32) + rand.Read(key) + + cipher, err := NewGCMCipher(key) + if err != nil { + t.Fatalf("Failed to create GCM cipher: %v", err) + } + + // Test encryption and decryption + var encrypted bytes.Buffer + var decrypted bytes.Buffer + + // Encrypt + err = cipher.EncryptStream(bytes.NewReader(plaintext), &encrypted) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + // Decrypt + err = cipher.DecryptStream(bytes.NewReader(encrypted.Bytes()), &decrypted) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + // Verify + if !bytes.Equal(plaintext, decrypted.Bytes()) { + t.Error("Decrypted data doesn't match original") + } + }) + } +} + +func TestGCMCipherErrors(t *testing.T) { + tests := []struct { + name string + key []byte + wantErr string + }{ + { + name: "invalid key size", + key: make([]byte, 15), + wantErr: "invalid key size", + }, + { + name: "nil key", + key: nil, + wantErr: "invalid key size", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewGCMCipher(tt.key) + if err == nil { + t.Error("Expected error but got none") + return + } + if err.Error() != tt.wantErr { + t.Errorf("Expected error '%s', got '%s'", tt.wantErr, err.Error()) + } + }) + } +} + +func TestGCMAuthenticationAndTampering(t *testing.T) { + plaintext := []byte("secret message") + key := make([]byte, 32) + rand.Read(key) + + cipher, err := NewGCMCipher(key) + if err != nil { + t.Fatalf("Failed to create cipher: %v", err) + } + + // Encrypt data + var encrypted bytes.Buffer + err = cipher.EncryptStream(bytes.NewReader(plaintext), &encrypted) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + // Tamper with encrypted data + encryptedData := encrypted.Bytes() + encryptedData[len(encryptedData)-1] ^= 0x01 // Flip last bit + + // Try to decrypt tampered data + var decrypted bytes.Buffer + err = cipher.DecryptStream(bytes.NewReader(encryptedData), &decrypted) + if err == nil { + t.Error("Expected authentication error for tampered data, got none") + } +} + +func TestGCMStreamErrors(t *testing.T) { + key := make([]byte, 32) + rand.Read(key) + + cipher, err := NewGCMCipher(key) + if err != nil { + t.Fatalf("Failed to create cipher: %v", err) + } + + failingReader := &failingReader{err: io.ErrUnexpectedEOF} + failingWriter := &failingWriter{err: io.ErrShortWrite} + + tests := []struct { + name string + test func() error + }{ + { + name: "encryption with failing reader", + test: func() error { + return cipher.EncryptStream(failingReader, &bytes.Buffer{}) + }, + }, + { + name: "encryption with failing writer", + test: func() error { + return cipher.EncryptStream(bytes.NewReader([]byte("test")), failingWriter) + }, + }, + { + name: "decryption with failing reader", + test: func() error { + return cipher.DecryptStream(failingReader, &bytes.Buffer{}) + }, + }, + { + name: "decryption with failing writer", + test: func() error { + return cipher.DecryptStream(bytes.NewReader(make([]byte, 32)), failingWriter) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.test(); err == nil { + t.Error("Expected error but got none") + } + }) + } +} \ No newline at end of file diff --git a/encryption/cipher.go b/encryption/cipher.go new file mode 100644 index 0000000..d692fd3 --- /dev/null +++ b/encryption/cipher.go @@ -0,0 +1,27 @@ +package encryption + +import "io" + +// StreamCipher defines the interface for stream encryption/decryption +type StreamCipher interface { + // EncryptStream encrypts data from src and writes to dst + EncryptStream(src io.Reader, dst io.Writer) error + + // DecryptStream decrypts data from src and writes to dst + DecryptStream(src io.Reader, dst io.Writer) error +} + +// NewCBCCipher creates a new AES-CBC cipher with the given key and IV +func NewCBCCipher(key, iv []byte) (StreamCipher, error) { + return &cbcCipher{ + key: key, + iv: iv, + }, nil +} + +// NewGCMCipher creates a new AES-GCM cipher with the given key +func NewGCMCipher(key []byte) (StreamCipher, error) { + return &gcmCipher{ + key: key, + }, nil +} \ No newline at end of file diff --git a/walrus.go b/walrus.go index 54a9df5..877fe08 100644 --- a/walrus.go +++ b/walrus.go @@ -10,27 +10,29 @@ import ( "os" "strconv" "time" + + "github.com/suiet/walrus-go/encryption" ) // RetryConfig defines the retry configuration type RetryConfig struct { - MaxRetries int // Maximum number of retry attempts - RetryDelay time.Duration // Delay between retries + MaxRetries int // Maximum number of retry attempts + RetryDelay time.Duration // Delay between retries } // Client is a client for interacting with the Walrus API type Client struct { - AggregatorURL []string - PublisherURL []string - httpClient *http.Client - retryConfig RetryConfig // Add retry configuration - // MaxUnknownLengthUploadSize specifies the maximum allowed size in bytes for uploads - // when the content length is not known in advance (i.e., contentLength <= 0). - // In such cases, the entire content must be read into memory to determine its size, - // which could potentially cause memory issues with very large uploads. - // This limit helps prevent memory exhaustion in those scenarios. - // Default is 5MB. - MaxUnknownLengthUploadSize int64 + AggregatorURL []string + PublisherURL []string + httpClient *http.Client + retryConfig RetryConfig // Add retry configuration + // MaxUnknownLengthUploadSize specifies the maximum allowed size in bytes for uploads + // when the content length is not known in advance (i.e., contentLength <= 0). + // In such cases, the entire content must be read into memory to determine its size, + // which could potentially cause memory issues with very large uploads. + // This limit helps prevent memory exhaustion in those scenarios. + // Default is 5MB. + MaxUnknownLengthUploadSize int64 } // ClientOption defines a function type that modifies Client options @@ -38,39 +40,39 @@ type ClientOption func(*Client) // WithAggregatorURLs sets custom aggregator URLs for the client func WithAggregatorURLs(urls []string) ClientOption { - return func(c *Client) { - if len(urls) > 0 { - c.AggregatorURL = urls - } - } + return func(c *Client) { + if len(urls) > 0 { + c.AggregatorURL = urls + } + } } // WithPublisherURLs sets custom publisher URLs for the client func WithPublisherURLs(urls []string) ClientOption { - return func(c *Client) { - if len(urls) > 0 { - c.PublisherURL = urls - } - } + return func(c *Client) { + if len(urls) > 0 { + c.PublisherURL = urls + } + } } // WithHTTPClient sets a custom HTTP client for the Walrus client func WithHTTPClient(httpClient *http.Client) ClientOption { - return func(c *Client) { - if httpClient != nil { - c.httpClient = httpClient - } - } + return func(c *Client) { + if httpClient != nil { + c.httpClient = httpClient + } + } } // WithRetryConfig sets the retry configuration for the client func WithRetryConfig(maxRetries int, retryDelay time.Duration) ClientOption { - return func(c *Client) { - c.retryConfig = RetryConfig{ - MaxRetries: maxRetries, - RetryDelay: retryDelay, - } - } + return func(c *Client) { + c.retryConfig = RetryConfig{ + MaxRetries: maxRetries, + RetryDelay: retryDelay, + } + } } // WithMaxUnknownLengthUploadSize sets the maximum allowed size for uploads when content length @@ -79,436 +81,483 @@ func WithRetryConfig(maxRetries int, retryDelay time.Duration) ClientOption { // This limit helps prevent potential memory exhaustion in such cases. // Default is 5MB. func WithMaxUnknownLengthUploadSize(maxSize int64) ClientOption { - return func(c *Client) { - if maxSize > 0 { - c.MaxUnknownLengthUploadSize = maxSize - } - } + return func(c *Client) { + if maxSize > 0 { + c.MaxUnknownLengthUploadSize = maxSize + } + } } // NewClient creates a new Walrus client with the specified options func NewClient(opts ...ClientOption) *Client { - // Create client with default values - client := &Client{ - AggregatorURL: DefaultTestnetAggregators, - PublisherURL: DefaultTestnetPublishers, - httpClient: &http.Client{}, - retryConfig: RetryConfig{ - MaxRetries: 5, // Default to 5 retries - RetryDelay: 500 * time.Millisecond, // Default to 500ms delay - }, - MaxUnknownLengthUploadSize: 5 * 1024 * 1024, // Default to 5MB - } - - // Apply all options - for _, opt := range opts { - opt(client) - } - - return client + // Create client with default values + client := &Client{ + AggregatorURL: DefaultTestnetAggregators, + PublisherURL: DefaultTestnetPublishers, + httpClient: &http.Client{}, + retryConfig: RetryConfig{ + MaxRetries: 5, // Default to 5 retries + RetryDelay: 500 * time.Millisecond, // Default to 500ms delay + }, + MaxUnknownLengthUploadSize: 5 * 1024 * 1024, // Default to 5MB + } + + // Apply all options + for _, opt := range opts { + opt(client) + } + + return client } // EncryptionOptions defines the encryption configuration type EncryptionOptions struct { - // Key used for encryption/decryption - Key []byte + // Key used for encryption/decryption + Key []byte + // Mode specifies the encryption mode ("CBC" or "GCM") + Mode string + // IV is only required for CBC mode + IV []byte } // StoreOptions defines options for storing data type StoreOptions struct { - Epochs int // Number of storage epochs - // Encryption configuration, if nil encryption is disabled - Encryption *EncryptionOptions + Epochs int // Number of storage epochs + // Encryption configuration, if nil encryption is disabled + Encryption *EncryptionOptions } // ReadOptions defines options for reading data type ReadOptions struct { - // Encryption configuration for decryption, if nil decryption is disabled - Encryption *EncryptionOptions + // Encryption configuration for decryption, if nil decryption is disabled + Encryption *EncryptionOptions } // BlobInfo represents the information returned after storing data type BlobInfo struct { - BlobID string `json:"blobId"` - EndEpoch int `json:"endEpoch"` + BlobID string `json:"blobId"` + EndEpoch int `json:"endEpoch"` } // BlobObject represents the blob object information type BlobObject struct { - ID string `json:"id"` - StoredEpoch int `json:"storedEpoch"` - BlobID string `json:"blobId"` - Size int64 `json:"size"` - ErasureCodeType string `json:"erasureCodeType"` - CertifiedEpoch int `json:"certifiedEpoch"` - Storage StorageInfo `json:"storage"` + ID string `json:"id"` + StoredEpoch int `json:"storedEpoch"` + BlobID string `json:"blobId"` + Size int64 `json:"size"` + ErasureCodeType string `json:"erasureCodeType"` + CertifiedEpoch int `json:"certifiedEpoch"` + Storage StorageInfo `json:"storage"` } // StoreResponse represents the unified response for store operations type StoreResponse struct { - Blob BlobInfo `json:"blobInfo,omitempty"` - - // For newly created blobs - NewlyCreated *struct { - BlobObject BlobObject `json:"blobObject"` - EncodedSize int `json:"encodedSize"` - Cost int `json:"cost"` - } `json:"newlyCreated,omitempty"` - - // For already certified blobs - AlreadyCertified *struct { - BlobID string `json:"blobId"` - Event EventInfo `json:"event"` - EndEpoch int `json:"endEpoch"` - } `json:"alreadyCertified,omitempty"` + Blob BlobInfo `json:"blobInfo,omitempty"` + + // For newly created blobs + NewlyCreated *struct { + BlobObject BlobObject `json:"blobObject"` + EncodedSize int `json:"encodedSize"` + Cost int `json:"cost"` + } `json:"newlyCreated,omitempty"` + + // For already certified blobs + AlreadyCertified *struct { + BlobID string `json:"blobId"` + Event EventInfo `json:"event"` + EndEpoch int `json:"endEpoch"` + } `json:"alreadyCertified,omitempty"` } // NormalizeBlobResponse is a helper function to normalize the response from the blob service func (resp *StoreResponse) NormalizeBlobResponse() { - if resp.AlreadyCertified != nil { - resp.Blob.BlobID = resp.AlreadyCertified.BlobID - resp.Blob.EndEpoch = resp.AlreadyCertified.EndEpoch - } - - if resp.NewlyCreated != nil { - resp.Blob.BlobID = resp.NewlyCreated.BlobObject.BlobID - resp.Blob.EndEpoch = resp.NewlyCreated.BlobObject.Storage.EndEpoch - } + if resp.AlreadyCertified != nil { + resp.Blob.BlobID = resp.AlreadyCertified.BlobID + resp.Blob.EndEpoch = resp.AlreadyCertified.EndEpoch + } + + if resp.NewlyCreated != nil { + resp.Blob.BlobID = resp.NewlyCreated.BlobObject.BlobID + resp.Blob.EndEpoch = resp.NewlyCreated.BlobObject.Storage.EndEpoch + } } // EventInfo represents the certification event information type EventInfo struct { - TxDigest string `json:"txDigest"` - EventSeq string `json:"eventSeq"` + TxDigest string `json:"txDigest"` + EventSeq string `json:"eventSeq"` } // StorageInfo represents the storage information for a blob type StorageInfo struct { - ID string `json:"id"` - StartEpoch int `json:"startEpoch"` - EndEpoch int `json:"endEpoch"` - StorageSize int `json:"storageSize"` + ID string `json:"id"` + StartEpoch int `json:"startEpoch"` + EndEpoch int `json:"endEpoch"` + StorageSize int `json:"storageSize"` } // BlobMetadata represents the metadata information returned by Head request type BlobMetadata struct { - ContentLength int64 `json:"content-length"` - ContentType string `json:"content-type"` - LastModified string `json:"last-modified"` - ETag string `json:"etag"` + ContentLength int64 `json:"content-length"` + ContentType string `json:"content-type"` + LastModified string `json:"last-modified"` + ETag string `json:"etag"` +} + +// Add a helper function to create cipher +func (opts *EncryptionOptions) getCipher() (encryption.StreamCipher, error) { + if opts == nil || len(opts.Key) == 0 { + return nil, fmt.Errorf("encryption key is required") + } + + switch opts.Mode { + case "CBC": + if len(opts.IV) == 0 { + return nil, fmt.Errorf("IV is required for CBC mode") + } + return encryption.NewCBCCipher(opts.Key, opts.IV) + case "GCM", "": // Default to GCM if no mode is specified + return encryption.NewGCMCipher(opts.Key) + default: + return nil, fmt.Errorf("unsupported encryption mode: %s", opts.Mode) + } } // Store stores data on the Walrus Publisher and returns the complete store response func (c *Client) Store(data []byte, opts *StoreOptions) (*StoreResponse, error) { - urlStr := "/v1/store" - if opts != nil && opts.Epochs > 0 { - urlStr += "?epochs=" + strconv.Itoa(opts.Epochs) - } - - var reader io.Reader = bytes.NewReader(data) - - // If encryption is enabled - if opts != nil && opts.Encryption != nil { - var buf bytes.Buffer - if err := EncryptStream(opts.Encryption.Key, bytes.NewReader(data), &buf); err != nil { - return nil, fmt.Errorf("failed to encrypt data: %w", err) - } - reader = &buf - } - - req, err := http.NewRequest("PUT", urlStr, reader) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/octet-stream") - - resp, err := c.doWithRetry(req, c.PublisherURL) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respData, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - var storeResp StoreResponse - if err := json.Unmarshal(respData, &storeResp); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - storeResp.NormalizeBlobResponse() - - return &storeResp, nil + urlStr := "/v1/store" + if opts != nil && opts.Epochs > 0 { + urlStr += "?epochs=" + strconv.Itoa(opts.Epochs) + } + + var reader io.Reader = bytes.NewReader(data) + + // If encryption is enabled + if opts != nil && opts.Encryption != nil { + cipher, err := opts.Encryption.getCipher() + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + var buf bytes.Buffer + if err := cipher.EncryptStream(bytes.NewReader(data), &buf); err != nil { + return nil, fmt.Errorf("failed to encrypt data: %w", err) + } + reader = &buf + } + + req, err := http.NewRequest("PUT", urlStr, reader) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := c.doWithRetry(req, c.PublisherURL) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respData, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var storeResp StoreResponse + if err := json.Unmarshal(respData, &storeResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + storeResp.NormalizeBlobResponse() + + return &storeResp, nil } // StoreFromReader stores data from an io.Reader and returns the complete store response func (c *Client) StoreFromReader(reader io.Reader, opts *StoreOptions) (*StoreResponse, error) { - urlStr := "/v1/store" - if opts != nil && opts.Epochs > 0 { - urlStr += "?epochs=" + strconv.Itoa(opts.Epochs) - } - - var err error - - // If encryption is enabled - if opts != nil && opts.Encryption != nil { - var buf bytes.Buffer - if err := EncryptStream(opts.Encryption.Key, reader, &buf); err != nil { - return nil, fmt.Errorf("failed to encrypt data: %w", err) - } - reader = &buf - } - - // Create request with the proper reader - req, err := http.NewRequest("PUT", urlStr, reader) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/octet-stream") - - resp, err := c.doWithRetry(req, c.PublisherURL) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respData, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - var storeResp StoreResponse - if err := json.Unmarshal(respData, &storeResp); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - storeResp.NormalizeBlobResponse() - return &storeResp, nil + urlStr := "/v1/store" + if opts != nil && opts.Epochs > 0 { + urlStr += "?epochs=" + strconv.Itoa(opts.Epochs) + } + + var err error + + // If encryption is enabled + if opts != nil && opts.Encryption != nil { + cipher, err := opts.Encryption.getCipher() + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + var buf bytes.Buffer + if err := cipher.EncryptStream(reader, &buf); err != nil { + return nil, fmt.Errorf("failed to encrypt data: %w", err) + } + reader = &buf + } + + // Create request with the proper reader + req, err := http.NewRequest("PUT", urlStr, reader) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := c.doWithRetry(req, c.PublisherURL) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respData, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var storeResp StoreResponse + if err := json.Unmarshal(respData, &storeResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + storeResp.NormalizeBlobResponse() + return &storeResp, nil } // StoreFromURL downloads and stores content from URL and returns the complete store response func (c *Client) StoreFromURL(sourceURL string, opts *StoreOptions) (*StoreResponse, error) { - req, err := http.NewRequest("GET", sourceURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to download from URL: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("failed to download from URL %s: HTTP request returned status code %d, expected 200 OK", sourceURL, resp.StatusCode) - } - - return c.StoreFromReader(resp.Body, opts) + req, err := http.NewRequest("GET", sourceURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to download from URL: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to download from URL %s: HTTP request returned status code %d, expected 200 OK", sourceURL, resp.StatusCode) + } + + return c.StoreFromReader(resp.Body, opts) } // StoreFile stores a file and returns the complete store response func (c *Client) StoreFile(filePath string, opts *StoreOptions) (*StoreResponse, error) { - file, err := os.Open(filePath) - if err != nil { - return nil, err - } - defer file.Close() + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer file.Close() - return c.StoreFromReader(file, opts) + return c.StoreFromReader(file, opts) } // Read retrieves a blob from the Walrus Aggregator func (c *Client) Read(blobID string, opts *ReadOptions) ([]byte, error) { - urlStr := fmt.Sprintf("/v1/%s", url.PathEscape(blobID)) - - req, err := http.NewRequest(http.MethodGet, urlStr, nil) - if err != nil { - return nil, err - } - - resp, err := c.doWithRetry(req, c.AggregatorURL) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - // If decryption is enabled - if opts != nil && opts.Encryption != nil { - var decryptedBuf bytes.Buffer - if err := DecryptStream(opts.Encryption.Key, resp.Body, &decryptedBuf); err != nil { - return nil, fmt.Errorf("failed to decrypt data: %w", err) - } - return decryptedBuf.Bytes(), nil - } - - return io.ReadAll(resp.Body) + urlStr := fmt.Sprintf("/v1/%s", url.PathEscape(blobID)) + + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + if err != nil { + return nil, err + } + + resp, err := c.doWithRetry(req, c.AggregatorURL) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // If decryption is enabled + if opts != nil && opts.Encryption != nil { + cipher, err := opts.Encryption.getCipher() + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + var decryptedBuf bytes.Buffer + if err := cipher.DecryptStream(resp.Body, &decryptedBuf); err != nil { + return nil, fmt.Errorf("failed to decrypt data: %w", err) + } + return decryptedBuf.Bytes(), nil + } + + return io.ReadAll(resp.Body) } // ReadToFile retrieves a blob and writes it to a file func (c *Client) ReadToFile(blobID, filePath string, opts *ReadOptions) error { - urlStr := fmt.Sprintf("/v1/%s", url.PathEscape(blobID)) - - req, err := http.NewRequest(http.MethodGet, urlStr, nil) - if err != nil { - return err - } - - resp, err := c.doWithRetry(req, c.AggregatorURL) - if err != nil { - return err - } - defer resp.Body.Close() - - // Create the file - outFile, err := os.Create(filePath) - if err != nil { - return err - } - defer outFile.Close() - - // If decryption is enabled - if opts != nil && opts.Encryption != nil { - return DecryptStream(opts.Encryption.Key, resp.Body, outFile) - } - - // Write the response body to the file - _, err = io.Copy(outFile, resp.Body) - return err + urlStr := fmt.Sprintf("/v1/%s", url.PathEscape(blobID)) + + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + if err != nil { + return err + } + + resp, err := c.doWithRetry(req, c.AggregatorURL) + if err != nil { + return err + } + defer resp.Body.Close() + + // Create the file + outFile, err := os.Create(filePath) + if err != nil { + return err + } + defer outFile.Close() + + // If decryption is enabled + if opts != nil && opts.Encryption != nil { + cipher, err := opts.Encryption.getCipher() + if err != nil { + return fmt.Errorf("failed to create cipher: %w", err) + } + return cipher.DecryptStream(resp.Body, outFile) + } + + // Write the response body to the file + _, err = io.Copy(outFile, resp.Body) + return err } // GetAPISpec retrieves the API specification from the aggregator or publisher func (c *Client) GetAPISpec(isAggregator bool) ([]byte, error) { - urlStr := "/v1/api" - - req, err := http.NewRequest(http.MethodGet, urlStr, nil) - if err != nil { - return nil, err - } - - urls := c.PublisherURL - if isAggregator { - urls = c.AggregatorURL - } - - resp, err := c.doWithRetry(req, urls) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - return io.ReadAll(resp.Body) + urlStr := "/v1/api" + + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + if err != nil { + return nil, err + } + + urls := c.PublisherURL + if isAggregator { + urls = c.AggregatorURL + } + + resp, err := c.doWithRetry(req, urls) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return io.ReadAll(resp.Body) } // Head retrieves blob metadata from the Walrus Aggregator without downloading the content func (c *Client) Head(blobID string) (*BlobMetadata, error) { - urlStr := fmt.Sprintf("/v1/%s", url.PathEscape(blobID)) - - req, err := http.NewRequest(http.MethodHead, urlStr, nil) - if err != nil { - return nil, fmt.Errorf("failed to create HEAD request: %w", err) - } - - resp, err := c.doWithRetry(req, c.AggregatorURL) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - metadata := &BlobMetadata{ - ContentLength: resp.ContentLength, - ContentType: resp.Header.Get("Content-Type"), - LastModified: resp.Header.Get("Last-Modified"), - ETag: resp.Header.Get("ETag"), - } - - return metadata, nil + urlStr := fmt.Sprintf("/v1/%s", url.PathEscape(blobID)) + + req, err := http.NewRequest(http.MethodHead, urlStr, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HEAD request: %w", err) + } + + resp, err := c.doWithRetry(req, c.AggregatorURL) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + metadata := &BlobMetadata{ + ContentLength: resp.ContentLength, + ContentType: resp.Header.Get("Content-Type"), + LastModified: resp.Header.Get("Last-Modified"), + ETag: resp.Header.Get("ETag"), + } + + return metadata, nil } // ReadToReader retrieves a blob and writes it to the provided io.Writer -func (c *Client) ReadToReader(blobID string, options *ReadOptions) (io.ReadCloser, error) { - urlStr := fmt.Sprintf("/v1/%s", url.PathEscape(blobID)) - - req, err := http.NewRequest(http.MethodGet, urlStr, nil) - if err != nil { - return nil, err - } - - resp, err := c.doWithRetry(req, c.AggregatorURL) - if err != nil { - return nil, err - } - - // If decryption is enabled - if options != nil && options.Encryption != nil { - var decryptedBuf bytes.Buffer - if err := DecryptStream(options.Encryption.Key, resp.Body, &decryptedBuf); err != nil { - return nil, fmt.Errorf("failed to decrypt data: %w", err) - } - return io.NopCloser(&decryptedBuf), nil - } - - return resp.Body, nil +func (c *Client) ReadToReader(blobID string, opts *ReadOptions) (io.ReadCloser, error) { + urlStr := fmt.Sprintf("/v1/%s", url.PathEscape(blobID)) + + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + if err != nil { + return nil, err + } + + resp, err := c.doWithRetry(req, c.AggregatorURL) + if err != nil { + return nil, err + } + + // If decryption is enabled + if opts != nil && opts.Encryption != nil { + cipher, err := opts.Encryption.getCipher() + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + var decryptedBuf bytes.Buffer + if err := cipher.DecryptStream(resp.Body, &decryptedBuf); err != nil { + return nil, fmt.Errorf("failed to decrypt data: %w", err) + } + return io.NopCloser(&decryptedBuf), nil + } + + return resp.Body, nil } // doWithRetry performs an HTTP request with retry logic func (c *Client) doWithRetry(req *http.Request, urls []string) (*http.Response, error) { - var lastErr error - // Calculate total attempts based on retry config and URL count - totalAttempts := c.retryConfig.MaxRetries + 1 - attemptCount := 0 - - // Try URLs in round-robin fashion until max retries reached - for attemptCount < totalAttempts { - // Get URL index for this attempt - urlIndex := attemptCount % len(urls) - baseURL := urls[urlIndex] - - // Update request URL with current base URL - req.URL.Host = "" - req.URL.Scheme = "" - fullURL := baseURL + req.URL.String() - req.URL, _ = url.Parse(fullURL) - - // Create a new request for this attempt (since the original body might have been consumed) - newReq := &http.Request{} - *newReq = *req - if req.Body != nil { - bodyBytes, err := io.ReadAll(req.Body) - if err != nil { - return nil, fmt.Errorf("failed to read request body: %w", err) - } - req.Body.Close() - newReq.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - resp, err := c.httpClient.Do(newReq) - if err == nil && resp.StatusCode == http.StatusOK { - return resp, nil - } - - if err != nil { - lastErr = err - } else { - // Attempt to read error message from response body for better error reporting - errBody, readErr := io.ReadAll(resp.Body) - resp.Body.Close() - if readErr == nil && len(errBody) > 0 { - lastErr = fmt.Errorf("request failed with status code %d: %s", resp.StatusCode, string(errBody)) - } else { - lastErr = fmt.Errorf("request failed with status code %d", resp.StatusCode) - } - } - - // Sleep before next attempt if not the last attempt - if attemptCount < totalAttempts-1 { - time.Sleep(c.retryConfig.RetryDelay) - } - - attemptCount++ - } - - return nil, fmt.Errorf("all retry attempts failed: %w", lastErr) + var lastErr error + // Calculate total attempts based on retry config and URL count + totalAttempts := c.retryConfig.MaxRetries + 1 + attemptCount := 0 + + // Try URLs in round-robin fashion until max retries reached + for attemptCount < totalAttempts { + // Get URL index for this attempt + urlIndex := attemptCount % len(urls) + baseURL := urls[urlIndex] + + // Update request URL with current base URL + req.URL.Host = "" + req.URL.Scheme = "" + fullURL := baseURL + req.URL.String() + req.URL, _ = url.Parse(fullURL) + + // Create a new request for this attempt (since the original body might have been consumed) + newReq := &http.Request{} + *newReq = *req + if req.Body != nil { + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + req.Body.Close() + newReq.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + resp, err := c.httpClient.Do(newReq) + if err == nil && resp.StatusCode == http.StatusOK { + return resp, nil + } + + if err != nil { + lastErr = err + } else { + // Attempt to read error message from response body for better error reporting + errBody, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + if readErr == nil && len(errBody) > 0 { + lastErr = fmt.Errorf("request failed with status code %d: %s", resp.StatusCode, string(errBody)) + } else { + lastErr = fmt.Errorf("request failed with status code %d", resp.StatusCode) + } + } + + // Sleep before next attempt if not the last attempt + if attemptCount < totalAttempts-1 { + time.Sleep(c.retryConfig.RetryDelay) + } + + attemptCount++ + } + + return nil, fmt.Errorf("all retry attempts failed: %w", lastErr) } diff --git a/walrus_test.go b/walrus_test.go index 558e171..5763dc5 100644 --- a/walrus_test.go +++ b/walrus_test.go @@ -432,56 +432,64 @@ func TestRequestBodyPreservation(t *testing.T) { } } -// TestEncryption tests the encryption and decryption functionality +// TestEncryption tests both CBC and GCM encryption modes func TestEncryption(t *testing.T) { client := newTestClient(t) testData := []byte("Hello, Encrypted World!") - key := make([]byte, 32) // AES-256 key - rand.Read(key) - - // Test storing with encryption - storeOpts := &StoreOptions{ - Epochs: 1, - Encryption: &EncryptionOptions{ - Key: key, - }, - } - resp, err := client.Store(testData, storeOpts) - if err != nil { - t.Fatalf("Failed to store encrypted data: %v", err) - } - resp.NormalizeBlobResponse() - blobID := resp.Blob.BlobID - - // Test cases for reading + // Create test cases for each encryption mode tests := []struct { name string + storeOpts *StoreOptions readOpts *ReadOptions shouldMatch bool expectErr bool }{ { - name: "read with correct key", - readOpts: &ReadOptions{ + name: "GCM mode - correct key", + storeOpts: &StoreOptions{ + Epochs: 1, Encryption: &EncryptionOptions{ - Key: key, + Key: make([]byte, 32), // Will be filled with random data + Mode: "GCM", }, }, shouldMatch: true, expectErr: false, }, { - name: "read without decryption", - readOpts: nil, - shouldMatch: false, + name: "CBC mode - correct key", + storeOpts: &StoreOptions{ + Epochs: 1, + Encryption: &EncryptionOptions{ + Key: make([]byte, 32), // Will be filled with random data + Mode: "CBC", + IV: make([]byte, 16), // Will be filled with random data + }, + }, + shouldMatch: true, expectErr: false, }, { - name: "read with wrong key", - readOpts: &ReadOptions{ + name: "CBC mode - missing IV", + storeOpts: &StoreOptions{ + Epochs: 1, + Encryption: &EncryptionOptions{ + Key: make([]byte, 32), + Mode: "CBC", + // Missing IV + }, + }, + shouldMatch: false, + expectErr: true, + }, + { + name: "Invalid mode", + storeOpts: &StoreOptions{ + Epochs: 1, Encryption: &EncryptionOptions{ - Key: make([]byte, 32), // Different key + Key: make([]byte, 32), + Mode: "invalid", }, }, shouldMatch: false, @@ -491,17 +499,42 @@ func TestEncryption(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - retrieved, err := client.Read(blobID, tt.readOpts) - + // Generate random key and IV if needed + if tt.storeOpts != nil && tt.storeOpts.Encryption != nil { + rand.Read(tt.storeOpts.Encryption.Key) + if tt.storeOpts.Encryption.Mode == "CBC" && tt.storeOpts.Encryption.IV != nil { + rand.Read(tt.storeOpts.Encryption.IV) + } + } + + // Store with encryption + resp, err := client.Store(testData, tt.storeOpts) if tt.expectErr { if err == nil { t.Error("Expected error but got none") } return } + if err != nil { + t.Fatalf("Failed to store encrypted data: %v", err) + } + + resp.NormalizeBlobResponse() + blobID := resp.Blob.BlobID + // Create matching read options + readOpts := &ReadOptions{ + Encryption: &EncryptionOptions{ + Key: tt.storeOpts.Encryption.Key, + Mode: tt.storeOpts.Encryption.Mode, + IV: tt.storeOpts.Encryption.IV, + }, + } + + // Read with decryption + retrieved, err := client.Read(blobID, readOpts) if err != nil { - t.Fatalf("Failed to read data: %v", err) + t.Fatalf("Failed to read encrypted data: %v", err) } if tt.shouldMatch { @@ -509,10 +542,138 @@ func TestEncryption(t *testing.T) { t.Errorf("Retrieved data doesn't match original.\nExpected: %s\nGot: %s", string(testData), string(retrieved)) } - } else { - if bytes.Equal(retrieved, testData) { - t.Error("Retrieved data matches original when it shouldn't") - } + } + }) + } +} + +// TestEncryptionModeErrors tests error handling for different encryption modes +func TestEncryptionModeErrors(t *testing.T) { + client := newTestClient(t) + testData := []byte("Test Data") + + tests := []struct { + name string + opts *StoreOptions + errorMsg string + }{ + { + name: "CBC without IV", + opts: &StoreOptions{ + Encryption: &EncryptionOptions{ + Key: make([]byte, 32), + Mode: "CBC", + }, + }, + errorMsg: "IV is required for CBC mode", + }, + { + name: "Invalid mode", + opts: &StoreOptions{ + Encryption: &EncryptionOptions{ + Key: make([]byte, 32), + Mode: "XYZ", + }, + }, + errorMsg: "unsupported encryption mode: XYZ", + }, + { + name: "Empty key", + opts: &StoreOptions{ + Encryption: &EncryptionOptions{ + Mode: "GCM", + }, + }, + errorMsg: "encryption key is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.opts.Encryption.Key != nil { + rand.Read(tt.opts.Encryption.Key) + } + + _, err := client.Store(testData, tt.opts) + if err == nil { + t.Error("Expected error but got none") + return + } + + if !strings.Contains(err.Error(), tt.errorMsg) { + t.Errorf("Expected error containing '%s', got '%s'", tt.errorMsg, err.Error()) + } + }) + } +} + +// TestEncryptionLargeFile tests encryption and decryption of large files +func TestEncryptionLargeFile(t *testing.T) { + client := newTestClient(t) + + // Create 1MB test data + testData := make([]byte, 1024*1024) + rand.Read(testData) + + modes := []struct { + name string + opts *StoreOptions + }{ + { + name: "GCM mode", + opts: &StoreOptions{ + Encryption: &EncryptionOptions{ + Key: make([]byte, 32), + Mode: "GCM", + }, + }, + }, + { + name: "CBC mode", + opts: &StoreOptions{ + Encryption: &EncryptionOptions{ + Key: make([]byte, 32), + Mode: "CBC", + IV: make([]byte, 16), + }, + }, + }, + } + + for _, mode := range modes { + t.Run(mode.name, func(t *testing.T) { + // Generate random key and IV + rand.Read(mode.opts.Encryption.Key) + if mode.opts.Encryption.Mode == "CBC" { + rand.Read(mode.opts.Encryption.IV) + } + + // Store encrypted data + resp, err := client.Store(testData, mode.opts) + if err != nil { + t.Fatalf("Failed to store encrypted data: %v", err) + } + + resp.NormalizeBlobResponse() + blobID := resp.Blob.BlobID + + // Create matching read options + readOpts := &ReadOptions{ + Encryption: &EncryptionOptions{ + Key: mode.opts.Encryption.Key, + Mode: mode.opts.Encryption.Mode, + IV: mode.opts.Encryption.IV, + }, + } + + // Read and decrypt data + retrieved, err := client.Read(blobID, readOpts) + if err != nil { + t.Fatalf("Failed to read encrypted data: %v", err) + } + + if !bytes.Equal(retrieved, testData) { + t.Error("Retrieved data doesn't match original") } }) }