From adec74cf1f156521122f5d9113b5fc609115abbf Mon Sep 17 00:00:00 2001
From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Wed, 21 Aug 2024 09:19:34 +0000
Subject: [PATCH] Add quantum neural network and scientific domains

---
 src/NeuroFlex/quantum_nn_module.py            | 177 ++++++++++++++++++
 src/NeuroFlex/scientific_domains/__init__.py  |   0
 .../scientific_domains/quantum_domains.py     |  14 ++
 src/NeuroFlex/tensorflow_module.py            |  52 +++++
 src/NeuroFlex/tokenizer.py                    | 164 ++++++++++++++++
 src/NeuroFlex/vision_transformer.py           |  53 ++++++
 tests/test_self_curing_algorithm.py           |  55 ++++++
 tests/test_tokenizer.py                       |  79 ++++++++
 8 files changed, 594 insertions(+)
 create mode 100644 src/NeuroFlex/quantum_nn_module.py
 create mode 100644 src/NeuroFlex/scientific_domains/__init__.py
 create mode 100644 src/NeuroFlex/scientific_domains/quantum_domains.py
 create mode 100644 src/NeuroFlex/tensorflow_module.py
 create mode 100644 src/NeuroFlex/tokenizer.py
 create mode 100644 src/NeuroFlex/vision_transformer.py
 create mode 100644 tests/test_self_curing_algorithm.py
 create mode 100644 tests/test_tokenizer.py

diff --git a/src/NeuroFlex/quantum_nn_module.py b/src/NeuroFlex/quantum_nn_module.py
new file mode 100644
index 0000000..aedd742
--- /dev/null
+++ b/src/NeuroFlex/quantum_nn_module.py
@@ -0,0 +1,177 @@
+import jax
+import jax.numpy as jnp
+from jax import tree_util
+import flax.linen as nn
+import pennylane as qml
+import logging
+from typing import Callable, List, Tuple, Optional, Any, Dict
+from functools import partial
+from flax import struct
+
+class QuantumNeuralNetwork(nn.Module):
+    """
+    A quantum neural network module that combines classical and quantum computations.
+
+    This class implements a variational quantum circuit that can be used as a layer
+    in a hybrid quantum-classical neural network.
+
+    Attributes:
+        num_qubits (int): The number of qubits in the quantum circuit.
+        num_layers (int): The number of layers in the variational quantum circuit.
+        input_shape (Tuple[int, ...]): The shape of the input tensor.
+        output_shape (Tuple[int, ...]): The shape of the output tensor (excluding batch dimension).
+        max_retries (int): The maximum number of retries for quantum circuit execution.
+    """
+
+    num_qubits: int
+    num_layers: int
+    input_shape: Tuple[int, ...]
+    output_shape: Tuple[int, ...]
+    max_retries: int = 3
+    # device, qlayer and vmap_qlayer are created in setup(); declaring them as
+    # dataclass fields would make Flax treat them as frozen construction
+    # attributes and reject the assignments below, so they are not fields.
+
+    def setup(self):
+        logging.info(f"Setting up QuantumNeuralNetwork with {self.num_qubits} qubits, {self.num_layers} layers, input shape {self.input_shape}, and output shape {self.output_shape}")
+        self._validate_init_params()
+
+        self.weights = self.param('weights', nn.initializers.uniform(scale=0.1), (self.num_layers, self.num_qubits, 3))
+        try:
+            quantum_components = self._initialize_quantum_components()
+            self.device = quantum_components['device']
+            self.qlayer = quantum_components['qlayer']
+            self.vmap_qlayer = quantum_components['vmap_qlayer']
+            self.variable('quantum_components', 'components', lambda: quantum_components)
+        except Exception as e:
+            logging.error(f"Error initializing quantum components: {str(e)}")
+            fallback_components = self._fallback_initialization()
+            self.device = fallback_components['device']
+            self.qlayer = fallback_components['qlayer']
+            self.vmap_qlayer = fallback_components['vmap_qlayer']
+            self.variable('quantum_components', 'components', lambda: fallback_components)
+
+    def _validate_init_params(self):
+        if not isinstance(self.num_qubits, int) or self.num_qubits <= 0:
+            raise ValueError(f"Number of qubits must be a positive integer, got {self.num_qubits}")
+        if not isinstance(self.num_layers, int) or self.num_layers <= 0:
+            raise ValueError(f"Number of layers must be a positive integer, got {self.num_layers}")
+        if not isinstance(self.input_shape, tuple) or len(self.input_shape) != 2 or self.input_shape[1] != self.num_qubits:
+            raise ValueError(f"Invalid input_shape: {self.input_shape}. Expected shape (batch_size, {self.num_qubits})")
+        if not isinstance(self.output_shape, tuple) or len(self.output_shape) != 1 or self.output_shape[0] != self.num_qubits:
+            raise ValueError(f"Invalid output_shape: {self.output_shape}. Expected shape ({self.num_qubits},)")
+
+    def _initialize_quantum_components(self):
+        try:
+            device = qml.device("default.qubit", wires=self.num_qubits)
+            qlayer = qml.QNode(self.quantum_circuit, device, interface="jax")
+            vmap_qlayer = jax.vmap(qlayer, in_axes=(0, None))
+            logging.info("Quantum components created successfully")
+            return {
+                'device': device,
+                'qlayer': qlayer,
+                'vmap_qlayer': vmap_qlayer
+            }
+        except Exception as e:
+            logging.error(f"Error creating quantum components: {str(e)}")
+            return self._fallback_initialization()
+
+    def quantum_circuit(self, inputs: jnp.ndarray, weights: jnp.ndarray) -> List[qml.measurements.ExpectationMP]:
+        qml.AngleEmbedding(inputs, wires=range(self.num_qubits))
+        for l in range(self.num_layers):
+            qml.StronglyEntanglingLayers(jnp.expand_dims(weights[l], 0), wires=range(self.num_qubits))
+        return [qml.expval(qml.PauliZ(i)) for i in range(self.num_qubits)]
+
+    def validate_input_shape(self, x: jnp.ndarray) -> None:
+        if len(x.shape) != 2 or x.shape[1] != self.num_qubits:
+            raise ValueError(f"Input shape {x.shape} does not match expected shape (batch_size, {self.num_qubits})")
+
+    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> jnp.ndarray:
+        try:
+            self.validate_input_shape(x)
+            if jnp.any(jnp.isnan(x)) or jnp.any(jnp.isinf(x)):
+                raise ValueError(f"Input contains NaN or Inf values: {x}")
+
+            logging.debug(f"Executing quantum circuit with input shape: {x.shape}")
+            if self.vmap_qlayer is None:
+                logging.warning("Quantum components not initialized. Attempting initialization.")
Attempting initialization.") + self._initialize_quantum_components() + if self.vmap_qlayer is None: + logging.error("Quantum components initialization failed. Using fallback.") + return self._fallback_output(x) + + result_array = self._execute_quantum_circuit(x) + + expected_shape = (x.shape[0],) + self.output_shape + if result_array.shape != expected_shape: + logging.warning(f"Output shape mismatch. Expected {expected_shape}, got {result_array.shape}. Reshaping.") + result_array = jnp.reshape(result_array, expected_shape) + + result_array = jnp.clip(result_array, -1, 1) + logging.info(f"Quantum circuit executed successfully. Input shape: {x.shape}, Output shape: {result_array.shape}") + return result_array + except ValueError as ve: + logging.error(f"ValueError during quantum circuit execution: {str(ve)}") + return self._fallback_output(x) + except Exception as e: + logging.error(f"Unexpected error during quantum circuit execution: {str(e)}") + return self._fallback_output(x) + + def _execute_quantum_circuit(self, x: jnp.ndarray) -> jnp.ndarray: + weights = self.variable('params', 'weights').value + for attempt in range(self.max_retries): + try: + logging.debug(f"Attempt {attempt + 1}/{self.max_retries} to execute quantum circuit") + if self.vmap_qlayer is None: + raise ValueError("Quantum components not properly initialized") + result = self.vmap_qlayer(x, weights) + result_array = jnp.array(result) + if jnp.all(jnp.isfinite(result_array)): + logging.info(f"Quantum circuit execution successful on attempt {attempt + 1}") + return result_array + else: + raise ValueError("Quantum circuit produced non-finite values") + except Exception as e: + logging.warning(f"Quantum circuit execution failed on attempt {attempt + 1}: {str(e)}") + if attempt == self.max_retries - 1: + logging.error("Max retries reached. 
Quantum circuit execution failed.") + return self._fallback_output(x) + return self._fallback_output(x) # Ensure a return value if loop completes + + def _fallback_output(self, x: jnp.ndarray) -> jnp.ndarray: + fallback = jnp.zeros((x.shape[0],) + self.output_shape) + noise = jax.random.normal(jax.random.PRNGKey(0), fallback.shape) * 0.1 + return jnp.clip(fallback + noise, -1, 1) + + def _fallback_initialization(self): + logging.warning("Falling back to classical initialization") + fallback_components = { + 'device': None, + 'qlayer': lambda x, w: jnp.zeros(self.output_shape), + 'vmap_qlayer': jax.vmap(lambda x, w: jnp.zeros(self.output_shape), in_axes=(0, None)) + } + logging.info("Classical fallback initialization completed") + self.sow('quantum_components', 'components', fallback_components) + return fallback_components + + def reinitialize_device(self): + try: + new_device = qml.device("default.qubit", wires=self.num_qubits) + new_qlayer = qml.QNode(self.quantum_circuit, new_device, interface="jax") + new_vmap_qlayer = jax.vmap(new_qlayer, in_axes=(0, None)) + new_components = { + 'device': new_device, + 'qlayer': new_qlayer, + 'vmap_qlayer': new_vmap_qlayer + } + self.variable('quantum_components', 'components', lambda: new_components) + logging.info("Quantum device reinitialized successfully") + except Exception as e: + logging.error(f"Error reinitializing quantum device: {str(e)}") + fallback_components = self._fallback_initialization() + self.variable('quantum_components', 'components', lambda: fallback_components) + return self.variable('quantum_components', 'components').value + +@partial(jax.jit, static_argnums=(0, 1, 2, 3)) +def create_quantum_nn(num_qubits: int, num_layers: int, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> QuantumNeuralNetwork: + return QuantumNeuralNetwork(num_qubits=num_qubits, num_layers=num_layers, input_shape=input_shape, output_shape=output_shape) diff --git a/src/NeuroFlex/scientific_domains/__init__.py b/src/NeuroFlex/scientific_domains/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/NeuroFlex/scientific_domains/quantum_domains.py b/src/NeuroFlex/scientific_domains/quantum_domains.py new file mode 100644 index 0000000..1f4f09e --- /dev/null +++ b/src/NeuroFlex/scientific_domains/quantum_domains.py @@ -0,0 +1,14 @@ +import jax.numpy as jnp + +class QuantumDomains: + def __init__(self): + # Placeholder initialization + pass + + def simulate(self, state): + # Placeholder quantum simulation + return jnp.array(state) + + def measure(self, state): + # Placeholder measurement + return jnp.abs(state)**2 diff --git a/src/NeuroFlex/tensorflow_module.py b/src/NeuroFlex/tensorflow_module.py new file mode 100644 index 0000000..31beb2f --- /dev/null +++ b/src/NeuroFlex/tensorflow_module.py @@ -0,0 +1,52 @@ +# TensorFlow specific implementations will go here + +import tensorflow as tf +import keras + +# Example model using TensorFlow +class TensorFlowModel(keras.Model): + def __init__(self, features): + super(TensorFlowModel, self).__init__() + self.layers_ = keras.Sequential([ + keras.layers.Dense(100, activation='relu'), + keras.layers.Dense(features), + ]) + + def call(self, inputs): + return self.layers_(inputs) + +# Training function +@tf.function +def train_tf_model(model, X, y, epochs=10, lr=0.001): + optimizer = keras.optimizers.Adam(learning_rate=lr) + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + + @tf.function + def train_step(x, y): + with tf.GradientTape() as tape: + logits = model(x, 
diff --git a/src/NeuroFlex/tokenizer.py b/src/NeuroFlex/tokenizer.py
new file mode 100644
index 0000000..858eb77
--- /dev/null
+++ b/src/NeuroFlex/tokenizer.py
@@ -0,0 +1,164 @@
+import os
+import logging
+from typing import List, Optional
+
+import sentencepiece
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+class Tokenizer:
+
+    def __init__(self, model_path: Optional[str]):
+        if not model_path or not os.path.isfile(model_path):
+            raise ValueError(f"Invalid model path: {model_path}")
+        try:
+            self.sp_model = sentencepiece.SentencePieceProcessor()
+            self.sp_model.Load(model_path)
+            self.n_words: int = self.sp_model.GetPieceSize()
+            self.bos_id: int = self.sp_model.bos_id()
+            self.eos_id: int = self.sp_model.eos_id()
+            self.pad_id: int = self.sp_model.pad_id()
+            self.unk_id: int = self.sp_model.unk_id()
+            self.space_id: int = self.sp_model.PieceToId('▁')
+            self.special_chars = set('.,!?;:()[]{}"\'')
+            logging.info(f"Tokenizer initialized successfully with {self.n_words} tokens")
+        except Exception as e:
+            logging.error(f"Error initializing tokenizer: {str(e)}")
+            raise RuntimeError(f"Tokenizer initialization failed: {str(e)}")
+
+    def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
+        """
+        Converts a string into a list of tokens.
+
+        Args:
+            s (str): The input string to be encoded.
+            bos (bool): Whether to prepend the beginning-of-sentence token. Defaults to True.
+            eos (bool): Whether to append the end-of-sentence token. Defaults to False.
+
+        Returns:
+            List[int]: A list of token IDs.
+
+        Raises:
+            ValueError: If the input is not a string or if encoding fails.
+        """
+        if not isinstance(s, str):
+            raise ValueError(f"Invalid input type: {type(s)}. Input must be a string.")
Input must be a string.") + + tokens = [] + if bos and self.bos_id != -1: + tokens.append(self.bos_id) + + if s: + try: + parts = self._split_text(s) + for part in parts: + if part.strip() or part in self.special_chars: + if part in self.special_chars: + tokens.append(self.sp_model.PieceToId(f"▁{part}")) + else: + encoded = self.sp_model.EncodeAsIds(part) + if not encoded: + logging.warning(f"Failed to encode part: '{part}'") + encoded = [self.sp_model.PieceToId(char) if char not in self.special_chars + else self.sp_model.PieceToId(f"▁{char}") + for char in part] + tokens.extend(encoded) + elif part.isspace(): + tokens.append(self.space_id) + except Exception as e: + logging.error(f"Error encoding '{s}': {str(e)}") + raise ValueError(f"Encoding failed: {str(e)}") + elif not bos and not eos: + logging.debug("Empty input with no BOS/EOS tokens requested, returning empty list") + return [] + + if eos and self.eos_id != -1: + tokens.append(self.eos_id) + + tokens = [t if t != self.unk_id else self.sp_model.PieceToId('') for t in tokens] + + logging.debug(f"Encoded '{s}' to {len(tokens)} tokens") + return tokens + + def decode(self, t: List[int]) -> str: + """Converts a list of tokens into a string.""" + if not isinstance(t, list) or not t or not all(isinstance(token, int) for token in t): + logging.warning(f"Invalid input for decoding: {t}") + return "" + + try: + t = self._handle_special_tokens(t) + if not t: + return "" + decoded_text = self.sp_model.DecodeIds(t) + decoded_text = self._post_process_decoded_text(decoded_text) + if not decoded_text: + logging.warning("Decoding resulted in empty string. Using fallback method.") + decoded_text = self._fallback_decode(t) + if not decoded_text: + logging.warning("Fallback decoding also resulted in empty string.") + return "[DECODING_FAILED]" + logging.debug(f"Decoded {len(t)} tokens to: '{decoded_text}'") + return decoded_text + except Exception as e: + logging.error(f"Error during decoding: {str(e)}") + logging.debug(f"Problematic tokens: {t}") + return self._fallback_decode(t) or "[DECODING_FAILED]" + + def _fallback_decode(self, t: List[int]) -> str: + """Fallback method for decoding when the main method fails.""" + try: + return ''.join([self.sp_model.IdToPiece(token) for token in t if token != self.unk_id]) + except Exception as e: + logging.error(f"Fallback decoding failed: {str(e)}") + return "" + + def _handle_special_tokens(self, tokens: List[int]) -> List[int]: + """Handles special tokens like BOS, EOS, and PAD.""" + return [token for token in tokens if token not in {self.bos_id, self.eos_id, self.pad_id} and token != -1] + + def _post_process_decoded_text(self, text: str) -> str: + """Post-processes the decoded text to improve readability.""" + for punct in self.special_chars: + text = text.replace(f' {punct}', punct) + return ' '.join(text.split()).strip() + + def _split_text(self, text: str) -> List[str]: + """Splits text into parts, preserving special characters and whitespace.""" + parts = [] + current_part = "" + for char in text: + if char.isspace() or char in self.special_chars: + if current_part: + parts.append(current_part) + current_part = "" + parts.append(char) + else: + current_part += char + if current_part: + parts.append(current_part) + return parts + + def tokenize(self, text: str) -> List[str]: + """Tokenizes the input text into a list of token strings.""" + if not isinstance(text, str): + logging.warning(f"Invalid input type for tokenization: {type(text)}. 
Expected string.") + return [] + try: + tokens = self.sp_model.EncodeAsPieces(text) + return [token if token != '▁' else ' ' for token in tokens] + except Exception as e: + logging.error(f"Error tokenizing '{text}': {str(e)}") + return [] + + def detokenize(self, tokens: List[str]) -> str: + """Converts a list of token strings back into text.""" + if not isinstance(tokens, list) or not all(isinstance(token, str) for token in tokens): + logging.warning(f"Invalid input for detokenization: {tokens}") + return "" + try: + text = self.sp_model.DecodePieces(tokens) + return self._post_process_decoded_text(text) + except Exception as e: + logging.error(f"Error detokenizing tokens: {str(e)}") + return "" diff --git a/src/NeuroFlex/vision_transformer.py b/src/NeuroFlex/vision_transformer.py new file mode 100644 index 0000000..b816e63 --- /dev/null +++ b/src/NeuroFlex/vision_transformer.py @@ -0,0 +1,53 @@ +import jax +import jax.numpy as jnp +import flax.linen as nn + +class VisionTransformer(nn.Module): + num_classes: int + patch_size: int = 16 + hidden_size: int = 768 + num_heads: int = 12 + num_layers: int = 12 + mlp_dim: int = 3072 + dropout_rate: float = 0.1 + + @nn.compact + def __call__(self, x, train: bool = True): + # Assuming input shape is (batch_size, height, width, channels) + B, H, W, C = x.shape + assert H % self.patch_size == 0 and W % self.patch_size == 0, 'Image dimensions must be divisible by patch size.' + + # Split image into patches + x = jnp.reshape(x, (B, H // self.patch_size, W // self.patch_size, self.patch_size * self.patch_size * C)) + x = jnp.reshape(x, (B, -1, self.patch_size * self.patch_size * C)) + + # Embed patches + x = nn.Dense(self.hidden_size)(x) + + # Add position embeddings + n_patches = x.shape[1] + pos_embed = self.param('pos_embed', nn.initializers.normal(stddev=0.02), (1, n_patches, self.hidden_size)) + x = x + pos_embed + + # Transformer encoder + for _ in range(self.num_layers): + y = nn.LayerNorm()(x) + y = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(y, y) + y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train) + x = x + y + + y = nn.LayerNorm()(x) + y = nn.Dense(self.mlp_dim)(y) + y = nn.gelu(y) + y = nn.Dense(self.hidden_size)(y) + y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train) + x = x + y + + # Global average pooling + x = jnp.mean(x, axis=1) + + # Classification head + x = nn.LayerNorm()(x) + x = nn.Dense(self.num_classes)(x) + + return x diff --git a/tests/test_self_curing_algorithm.py b/tests/test_self_curing_algorithm.py new file mode 100644 index 0000000..44ca631 --- /dev/null +++ b/tests/test_self_curing_algorithm.py @@ -0,0 +1,55 @@ +# Test cases for the SelfCuringAlgorithm class + +import unittest +from NeuroFlex.model import SelfCuringAlgorithm, NeuroFlex + +class TestSelfCuringAlgorithm(unittest.TestCase): + def setUp(self): + # Create a mock model with various attributes + self.mock_model = NeuroFlex() + self.mock_model.is_trained = False + self.mock_model.performance = 0.5 + self.mock_model.data_quality = 0.7 + self.self_curing_algorithm = SelfCuringAlgorithm(self.mock_model) + + def test_diagnose_untrained_model(self): + # Test that the diagnose method identifies an untrained model + issues = self.self_curing_algorithm.diagnose() + self.assertIn("Model is not trained", issues) + + def test_heal_untrained_model(self): + # Test that the heal method trains an untrained model + issues = self.self_curing_algorithm.diagnose() + self.self_curing_algorithm.heal(issues) + 
diff --git a/tests/test_self_curing_algorithm.py b/tests/test_self_curing_algorithm.py
new file mode 100644
index 0000000..44ca631
--- /dev/null
+++ b/tests/test_self_curing_algorithm.py
@@ -0,0 +1,55 @@
+# Test cases for the SelfCuringAlgorithm class
+
+import unittest
+from NeuroFlex.model import SelfCuringAlgorithm, NeuroFlex
+
+class TestSelfCuringAlgorithm(unittest.TestCase):
+    def setUp(self):
+        # Create a mock model with various attributes
+        self.mock_model = NeuroFlex()
+        self.mock_model.is_trained = False
+        self.mock_model.performance = 0.5
+        self.mock_model.data_quality = 0.7
+        self.self_curing_algorithm = SelfCuringAlgorithm(self.mock_model)
+
+    def test_diagnose_untrained_model(self):
+        # Test that the diagnose method identifies an untrained model
+        issues = self.self_curing_algorithm.diagnose()
+        self.assertIn("Model is not trained", issues)
+
+    def test_heal_untrained_model(self):
+        # Test that the heal method trains an untrained model
+        issues = self.self_curing_algorithm.diagnose()
+        self.self_curing_algorithm.heal(issues)
+        self.assertTrue(self.mock_model.is_trained)
+
+    def test_diagnose_low_performance(self):
+        # Test that the diagnose method identifies low performance
+        self.mock_model.is_trained = True
+        self.mock_model.performance = 0.3
+        issues = self.self_curing_algorithm.diagnose()
+        self.assertIn("Model performance is low", issues)
+
+    def test_heal_low_performance(self):
+        # Test that the heal method improves model performance
+        self.mock_model.is_trained = True
+        self.mock_model.performance = 0.3
+        issues = self.self_curing_algorithm.diagnose()
+        self.self_curing_algorithm.heal(issues)
+        self.assertGreater(self.mock_model.performance, 0.3)
+
+    def test_diagnose_data_quality(self):
+        # Test that the diagnose method identifies poor data quality
+        self.mock_model.data_quality = 0.4
+        issues = self.self_curing_algorithm.diagnose()
+        self.assertIn("Poor data quality", issues)
+
+    def test_heal_data_quality(self):
+        # Test that the heal method improves data quality
+        self.mock_model.data_quality = 0.4
+        issues = self.self_curing_algorithm.diagnose()
+        self.self_curing_algorithm.heal(issues)
+        self.assertGreater(self.mock_model.data_quality, 0.4)
+
+if __name__ == '__main__':
+    unittest.main()
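Reviewer note (not part of the diff): `SelfCuringAlgorithm` itself is not added by this patch; the tests above only pin down its contract. Read together, they imply an interface roughly like the sketch below. The thresholds are illustrative guesses consistent with the test values, not confirmed by the source.

    class SelfCuringAlgorithm:  # hypothetical sketch of the contract the tests assume
        def __init__(self, model):
            self.model = model

        def diagnose(self) -> list:
            issues = []
            if not self.model.is_trained:
                issues.append("Model is not trained")
            if self.model.is_trained and self.model.performance < 0.5:  # threshold assumed
                issues.append("Model performance is low")
            if self.model.data_quality < 0.6:                           # threshold assumed
                issues.append("Poor data quality")
            return issues

        def heal(self, issues: list) -> None:
            if "Model is not trained" in issues:
                self.model.is_trained = True      # stands in for actual training
            if "Model performance is low" in issues:
                self.model.performance += 0.1     # stands in for retraining/tuning
            if "Poor data quality" in issues:
                self.model.data_quality += 0.1    # stands in for cleaning/augmentation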
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
new file mode 100644
index 0000000..bc2d7d1
--- /dev/null
+++ b/tests/test_tokenizer.py
@@ -0,0 +1,79 @@
+import unittest
+from unittest.mock import Mock, patch
+import os
+from NeuroFlex.tokenizer import Tokenizer
+
+class TestTokenizer(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Mock SentencePieceProcessor
+        cls.mock_sp = Mock()
+        cls.mock_sp.GetPieceSize.return_value = 1000
+        cls.mock_sp.bos_id.return_value = 1
+        cls.mock_sp.eos_id.return_value = 2
+        cls.mock_sp.pad_id.return_value = 0
+        cls.mock_sp.unk_id.return_value = 3
+        cls.mock_sp.PieceToId.return_value = 5  # keep piece ids integer so decode() accepts them
+        cls.mock_sp.EncodeAsIds.return_value = [10, 20, 30]
+        cls.mock_sp.DecodeIds.return_value = "Hello, world!"
+
+        # Mock os.path.isfile to return True for the dummy path
+        cls.mock_isfile = patch('os.path.isfile', return_value=True)
+        cls.mock_isfile.start()
+
+        # Patch SentencePieceProcessor in Tokenizer
+        with patch('sentencepiece.SentencePieceProcessor', return_value=cls.mock_sp):
+            cls.tokenizer = Tokenizer("dummy_path")
+
+    @classmethod
+    def tearDownClass(cls):
+        # Stop the os.path.isfile patch
+        cls.mock_isfile.stop()
+
+    def test_encode_basic(self):
+        text = "Hello, world!"
+        tokens = self.tokenizer.encode(text)
+        self.assertIsInstance(tokens, list)
+        self.assertGreater(len(tokens), 0)
+        self.assertEqual(tokens[0], self.tokenizer.bos_id)
+
+    def test_encode_no_bos(self):
+        text = "Hello, world!"
+        tokens = self.tokenizer.encode(text, bos=False)
+        self.assertNotEqual(tokens[0], self.tokenizer.bos_id)
+
+    def test_encode_with_eos(self):
+        text = "Hello, world!"
+        tokens = self.tokenizer.encode(text, eos=True)
+        self.assertEqual(tokens[-1], self.tokenizer.eos_id)
+
+    def test_decode_basic(self):
+        text = "Hello, world!"
+        tokens = self.tokenizer.encode(text)
+        decoded_text = self.tokenizer.decode(tokens)
+        self.assertIsInstance(decoded_text, str)
+        self.assertGreater(len(decoded_text), 0)
+
+    def test_encode_decode_roundtrip(self):
+        original_text = "This is a test sentence."
+        tokens = self.tokenizer.encode(original_text)
+        self.mock_sp.DecodeIds.return_value = original_text
+        decoded_text = self.tokenizer.decode(tokens)
+        self.assertEqual(original_text, decoded_text)
+
+    def test_empty_string(self):
+        text = ""
+        self.mock_sp.EncodeAsIds.return_value = []
+        tokens = self.tokenizer.encode(text)
+        self.assertEqual(len(tokens), 1)  # Should only contain the BOS token
+        self.mock_sp.DecodeIds.return_value = ""
+        decoded_text = self.tokenizer.decode(tokens)
+        self.assertEqual(decoded_text, "")
+
+    def test_special_tokens(self):
+        self.assertIsNotNone(self.tokenizer.bos_id)
+        self.assertIsNotNone(self.tokenizer.eos_id)
+        self.assertIsNotNone(self.tokenizer.pad_id)
+
+if __name__ == '__main__':
+    unittest.main()
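Reviewer note (not part of the diff): the tokenizer suite is self-contained (SentencePiece and the filesystem check are stubbed out), so it runs without any model artifacts; the self-curing suite additionally requires `NeuroFlex.model` to provide `SelfCuringAlgorithm` and `NeuroFlex`, which are outside this patch. Assuming the src/ layout is importable as `NeuroFlex` (e.g. after an editable install, or with PYTHONPATH set):

    PYTHONPATH=src python -m unittest tests.test_tokenizer tests.test_self_curing_algorithm -v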