From 308be7108ef3377007d711f50fae958191c3e8db Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Tue, 16 Jul 2024 08:31:31 -0700 Subject: [PATCH] Adds default reads/writes to burr actions This allows you to specify defaults if your action does not write. In the majority of cases they will be none, but this allows simple (static) arbitrary values. This specifically helps with the branching case -- e.g. where you have two options, and want to null out anything it doesn't write. For instance, an error and a result -- you'll only ever produce one or the other. This works both in the function and class-based approaches -- in the function-based it is part of the two decorators (@action/@streaming_action). In the class-based it is part of the class, overriding the default_reads and default_writes property function We add a bunch of new tests for default (as the code to handle multiple action types is fairly dispersed, for now), and also make the naming of the other tests/content more consistent. Note that this does not currently work with settings defaults to append/increment operations -- it will produce strange behavior. This is documented in all appropriate signatures. This also does not work (or even make sense) in the case that the function writes a default that it also reads. In that case, it will clobber the current value with the write value. To avoid this, we just error out if that is the case beforehand. --- burr/core/action.py | 152 +++++++++++++- burr/core/application.py | 82 ++++++-- burr/core/state.py | 33 ++- docs/reference/state.rst | 4 +- tests/core/test_action.py | 80 ++++++- tests/core/test_application.py | 370 +++++++++++++++++++++++++++++---- tests/test_end_to_end.py | 107 +++++++++- 7 files changed, 753 insertions(+), 75 deletions(-) diff --git a/burr/core/action.py b/burr/core/action.py index 5f35cfe6..40b1f0bd 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -6,6 +6,7 @@ import sys import types import typing +from abc import ABC from typing import ( Any, AsyncGenerator, @@ -42,6 +43,16 @@ def reads(self) -> list[str]: """ pass + @property + def default_reads(self) -> Dict[str, Any]: + """Default values to read from state if they are not there already. + This just fills out the gaps in state. This must be a subset + of the ``reads`` value. + + :return: + """ + return {} + @abc.abstractmethod def run(self, state: State, **run_kwargs) -> dict: """Runs the function on the given state and returns the result. @@ -122,18 +133,68 @@ def writes(self) -> list[str]: """ pass + @property + def default_writes(self) -> Dict[str, Any]: + """Default state writes for the reducer. If nothing writes this field from within + the reducer, then this will be written. Note that this is not (currently) + intended to work with append/increment operations. + + This must be a subset of the ``writes`` value. + + :return: A key/value dictionary of default writes. + """ + return {} + @abc.abstractmethod def update(self, result: dict, state: State) -> State: pass -class Action(Function, Reducer, abc.ABC): +class _PostValidator(abc.ABCMeta): + """Metaclass to allow for __post_init__ to be called after __init__. + While this is general we're keeping it here for now as it is only used + by the Action class. This enables us to ensure that the default_reads are correct. + """ + + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + if post := getattr(cls, "__post_init__", None): + post(instance) + return instance + + +class Action(Function, Reducer, ABC, metaclass=_PostValidator): def __init__(self): """Represents an action in a state machine. This is the base class from which actions extend. Note that this class needs to have a name set after the fact. """ self._name = None + def __post_init__(self): + self._validate_defaults() + + def _validate_defaults(self): + reads = set(self.reads) + missing_default_reads = {key for key in self.default_reads.keys() if key not in reads} + if missing_default_reads: + raise ValueError( + f"The following default state reads are not in the set of reads for action: {self}: {', '.join(missing_default_reads)}. " + f"Every read in default_reads must be in the reads list." + ) + writes = self.writes + missing_default_writes = {key for key in self.default_writes.keys() if key not in writes} + if missing_default_writes: + raise ValueError( + f"The following default state writes are not in the set of writes for action: {self}: {', '.join(missing_default_writes)}. " + f"Every write in default_writes must be in the writes list." + ) + default_writes_also_in_reads = {key for key in self.default_writes.keys() if key in reads} + if default_writes_also_in_reads: + raise ValueError( + f"The following default state writes are also in the reads for action: {self}: {', '.join(default_writes_also_in_reads)}. " + f"Every write in default_writes must not be in the reads list -- this leads to undefined behavior." + ) + def with_name(self, name: str) -> Self: """Returns a copy of the given action with the given name. Why do we need this? We instantiate actions without names, and then set them later. This is a way to @@ -484,6 +545,8 @@ def __init__( fn: Callable, reads: List[str], writes: List[str], + default_reads: Dict[str, Any] = None, + default_writes: Dict[str, Any] = None, bound_params: dict = None, ): """Instantiates a function-based action with the given function, reads, and writes. @@ -499,11 +562,21 @@ def __init__( self._writes = writes self._bound_params = bound_params if bound_params is not None else {} self._inputs = _get_inputs(self._bound_params, self._fn) + self._default_reads = default_reads if default_reads is not None else {} + self._default_writes = default_writes if default_writes is not None else {} @property def fn(self) -> Callable: return self._fn + @property + def default_reads(self) -> Dict[str, Any]: + return self._default_reads + + @property + def default_writes(self) -> Dict[str, Any]: + return self._default_writes + @property def reads(self) -> list[str]: return self._reads @@ -526,7 +599,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction": :return: """ return FunctionBasedAction( - self._fn, self._reads, self._writes, {**self._bound_params, **kwargs} + self._fn, + self._reads, + self._writes, + self.default_reads, + self._default_writes, + {**self._bound_params, **kwargs}, ) def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]: @@ -918,6 +996,8 @@ def __init__( ], reads: List[str], writes: List[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, bound_params: dict = None, ): """Instantiates a function-based streaming action with the given function, reads, and writes. @@ -931,6 +1011,8 @@ def __init__( self._fn = fn self._reads = reads self._writes = writes + self._default_reads = default_reads if default_reads is not None else {} + self._default_writes = default_writes if default_writes is not None else {} self._bound_params = bound_params if bound_params is not None else {} async def _a_stream_run_and_update( @@ -957,6 +1039,14 @@ def reads(self) -> list[str]: def writes(self) -> list[str]: return self._writes + @property + def default_writes(self) -> Dict[str, Any]: + return self._default_writes + + @property + def default_reads(self) -> Dict[str, Any]: + return self._default_reads + @property def streaming(self) -> bool: return True @@ -969,7 +1059,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction": :return: """ return FunctionBasedStreamingAction( - self._fn, self._reads, self._writes, {**self._bound_params, **kwargs} + self._fn, + self._reads, + self._writes, + self._default_reads, + self._default_writes, + {**self._bound_params, **kwargs}, ) @property @@ -999,7 +1094,10 @@ def bind(self, **kwargs: Any) -> Self: ... -def copy_func(f: types.FunctionType) -> types.FunctionType: +T = TypeVar("T", bound=types.FunctionType) + + +def copy_func(f: T) -> T: """Copies a function. This is used internally to bind parameters to a function so we don't accidentally overwrite them. @@ -1033,7 +1131,12 @@ def my_action(state: State, z: int) -> tuple[dict, State]: return self -def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]: +def action( + reads: List[str], + writes: List[str], + default_reads: Dict[str, Any] = None, + default_writes: Dict[str, Any] = None, +) -> Callable[[Callable], FunctionRepresentingAction]: """Decorator to create a function-based action. This is user-facing. Note that, in the future, with typed state, we may not need this for all cases. @@ -1044,11 +1147,27 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Function :param reads: Items to read from the state :param writes: Items to write to the state + :param default_reads: Default values for reads. If nothing upstream produces these, they will + be filled automatically. This is equivalent to adding + ``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})`` + at the beginning of your function. + :param default_writes: Default values for writes. If the action's state update does not write to this, + they will be filled automatically with the default values. Leaving blank will have no default values. + This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function. + Note that this will not work as intended with append/increment operations, so be careful. :return: The decorator to assign the function as an action """ + default_reads = default_reads if default_reads is not None else {} + default_writes = default_writes if default_writes is not None else {} def decorator(fn) -> FunctionRepresentingAction: - setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes)) + setattr( + fn, + FunctionBasedAction.ACTION_FUNCTION, + FunctionBasedAction( + fn, reads, writes, default_reads=default_reads, default_writes=default_writes + ), + ) setattr(fn, "bind", types.MethodType(bind, fn)) return fn @@ -1056,7 +1175,10 @@ def decorator(fn) -> FunctionRepresentingAction: def streaming_action( - reads: List[str], writes: List[str] + reads: List[str], + writes: List[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, ) -> Callable[[Callable], FunctionRepresentingAction]: """Decorator to create a streaming function-based action. This is user-facing. @@ -1090,14 +1212,28 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State] # return the final result return {'response': full_response}, state.update(response=full_response) + :param reads: Items to read from the state + :param writes: Items to write to the state + :param default_reads: Default values for reads. If nothing upstream produces these, they will + be filled automatically. This is equivalent to adding + ``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})`` + at the beginning of your function. + :param default_writes: Default values for writes. If the action's state update does not write to this, + they will be filled automatically with the default values. Leaving blank will have no default values. + This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function. + Note that this will not work as intended with append/increment operations, so be careful. + :return: The decorator to assign the function as an action + """ + default_reads = default_reads if default_reads is not None else {} + default_writes = default_writes if default_writes is not None else {} def wrapped(fn) -> FunctionRepresentingAction: fn = copy_func(fn) setattr( fn, FunctionBasedAction.ACTION_FUNCTION, - FunctionBasedStreamingAction(fn, reads, writes), + FunctionBasedStreamingAction(fn, reads, writes, default_reads, default_writes), ) setattr(fn, "bind", types.MethodType(bind, fn)) return fn diff --git a/burr/core/application.py b/burr/core/application.py index 69591f00..949db96c 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -83,6 +83,33 @@ def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_ _raise_fn_return_validation_error(output, action_name) +def _pre_apply_read_defaults( + state: State, + default_reads: Dict[str, Any], +): + """Applies default values to the state prior to execution. + This just applies them to the state so the action can overwrite them. + """ + state_update = {} + for key, value in default_reads.items(): + if key not in state: + state_update[key] = value + return state.update(**state_update) + + +def _pre_apply_write_defaults( + state: State, + default_writes: Dict[str, Any], +) -> State: + """Applies default values to the state prior to execution. + This just applies them to the state so the action can overwrite them. + """ + state_update = {} + for key, value in default_writes.items(): + state_update[key] = value + return state.update(**state_update) + + def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict: """Runs a function, returning the result of running the function. Note this restricts the keys in the state to only those that the @@ -100,6 +127,7 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name "instead...)" ) state_to_use = state.subset(*function.reads) + state_to_use = _pre_apply_read_defaults(state_to_use, function.default_reads) function.validate_inputs(inputs) result = function.run(state_to_use, **inputs) _validate_result(result, name) @@ -112,6 +140,7 @@ async def _arun_function( """Runs a function, returning the result of running the function. Async version of the above.""" state_to_use = state.subset(*function.reads) + state_to_use = _pre_apply_read_defaults(state_to_use, function.default_reads) function.validate_inputs(inputs) result = await function.run(state_to_use, **inputs) _validate_result(result, name) @@ -168,7 +197,8 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta :return: """ # TODO -- better guarding on state reads/writes - new_state = reducer.update(result, state) + state_with_defaults = _pre_apply_write_defaults(state, reducer.default_writes) + new_state = reducer.update(result, state_with_defaults) keys_in_new_state = set(new_state.keys()) new_keys = keys_in_new_state - set(state.keys()) extra_keys = new_keys - set(reducer.writes) @@ -216,6 +246,22 @@ def _format_BASE_ERROR_MESSAGE(action: Action, input_state: State, inputs: dict) return "\n" + border + "\n" + message + "\n" + border +def _prep_state_single_step_action(state: State, action: SingleStepAction): + """Runs default application for single step action. + First applies read defaults, then applies write defaults. Note write defaults + will blogger read + + :param state: + :param action: + :return: + """ + # first apply read defaults + state = _pre_apply_read_defaults(state, action.default_reads) + # then apply write defaults so if the action doesn't write it will be in state + state = _pre_apply_write_defaults(state, action.default_writes) + return state + + def _run_single_step_action( action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] ) -> Tuple[Dict[str, Any], State]: @@ -229,6 +275,7 @@ def _run_single_step_action( """ # TODO -- guard all reads/writes with a subset of the state action.validate_inputs(inputs) + state = _prep_state_single_step_action(state, action) result, new_state = _adjust_single_step_output( action.run_and_update(state, **inputs), action.name ) @@ -245,6 +292,7 @@ def _run_single_step_streaming_action( """Runs a single step streaming action. This API is internal-facing. This normalizes + validates the output.""" action.validate_inputs(inputs) + state = _prep_state_single_step_action(state, action) generator = action.stream_run_and_update(state, **inputs) result = None state_update = None @@ -269,11 +317,26 @@ def _run_single_step_streaming_action( yield result, state_update +async def _arun_single_step_action( + action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] +) -> Tuple[dict, State]: + """Runs a single step action in async. See the synchronous version for more details.""" + state_to_use = _prep_state_single_step_action(state, action) + action.validate_inputs(inputs) + result, new_state = _adjust_single_step_output( + await action.run_and_update(state_to_use, **inputs), action.name + ) + _validate_result(result, action.name) + _validate_reducer_writes(action, new_state, action.name) + return result, _state_update(state, new_state) + + async def _arun_single_step_streaming_action( action: SingleStepStreamingAction, state: State, inputs: Optional[Dict[str, Any]] ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: """Runs a single step streaming action in async. See the synchronous version for more details.""" action.validate_inputs(inputs) + state = _prep_state_single_step_action(state, action) generator = action.stream_run_and_update(state, **inputs) result = None state_update = None @@ -310,6 +373,7 @@ def _run_multi_step_streaming_action( This peeks ahead by one so we know when this is done (and when to validate). """ action.validate_inputs(inputs) + state = _pre_apply_read_defaults(state, action.default_reads) generator = action.stream_run(state, **inputs) result = None for item in generator: @@ -331,6 +395,7 @@ async def _arun_multi_step_streaming_action( ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: """Runs a multi-step streaming action in async. See the synchronous version for more details.""" action.validate_inputs(inputs) + state = _pre_apply_read_defaults(state, action.default_reads) generator = action.stream_run(state, **inputs) result = None async for item in generator: @@ -347,20 +412,6 @@ async def _arun_multi_step_streaming_action( yield result, state_update -async def _arun_single_step_action( - action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] -) -> Tuple[dict, State]: - """Runs a single step action in async. See the synchronous version for more details.""" - state_to_use = state - action.validate_inputs(inputs) - result, new_state = _adjust_single_step_output( - await action.run_and_update(state_to_use, **inputs), action.name - ) - _validate_result(result, action.name) - _validate_reducer_writes(action, new_state, action.name) - return result, _state_update(state, new_state) - - @dataclasses.dataclass class ApplicationGraph(Graph): """User-facing representation of the state machine. This has @@ -639,7 +690,6 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d return out async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True): - # we want to increment regardless of failure with self.context: next_action = self.get_next_action() if next_action is None: diff --git a/burr/core/state.py b/burr/core/state.py index 62c9a583..49a6fb3a 100644 --- a/burr/core/state.py +++ b/burr/core/state.py @@ -95,6 +95,11 @@ def writes(self) -> list[str]: """Returns the keys that this state delta writes""" pass + @abc.abstractmethod + def deletes(self) -> list[str]: + """Returns the keys that this state delta deletes""" + pass + @abc.abstractmethod def apply_mutate(self, inputs: dict): """Applies the state delta to the inputs""" @@ -117,6 +122,9 @@ def reads(self) -> list[str]: def writes(self) -> list[str]: return list(self.values.keys()) + def deletes(self) -> list[str]: + return [] + def apply_mutate(self, inputs: dict): inputs.update(self.values) @@ -137,13 +145,21 @@ def reads(self) -> list[str]: def writes(self) -> list[str]: return list(self.values.keys()) + def deletes(self) -> list[str]: + return [] + def apply_mutate(self, inputs: dict): for key, value in self.values.items(): if key not in inputs: inputs[key] = [] if not isinstance(inputs[key], list): raise ValueError(f"Cannot append to non-list value {key}={inputs[self.key]}") - inputs[key].append(value) + inputs[key] = [ + *inputs[key], + value, + ] # Not as efficient but safer, so we don't mutate the original list + # we're doing this to avoid a copy.deepcopy() call, so it is already more efficient than it was before + # That said, if one modifies prior values in the list, it is on them, and undefined behavior def validate(self, input_state: Dict[str, Any]): incorrect_types = {} @@ -171,6 +187,9 @@ def reads(self) -> list[str]: def writes(self) -> list[str]: return list(self.values.keys()) + def deletes(self) -> list[str]: + return [] + def validate(self, input_state: Dict[str, Any]): incorrect_types = {} for write_key in self.writes(): @@ -201,11 +220,14 @@ def name(cls) -> str: return "delete" def reads(self) -> list[str]: - return list(self.keys) + return [] def writes(self) -> list[str]: return [] + def deletes(self) -> list[str]: + return list(self.keys) + def apply_mutate(self, inputs: dict): for key in self.keys: inputs.pop(key, None) @@ -221,11 +243,12 @@ def __init__(self, initial_values: Dict[str, Any] = None): def apply_operation(self, operation: StateDelta) -> "State": """Applies a given operation to the state, returning a new state""" - new_state = copy.deepcopy(self._state) # TODO -- restrict to just the read keys + new_state = copy.copy(self._state) # TODO -- restrict to just the read keys operation.validate(new_state) operation.apply_mutate( new_state ) # todo -- validate that the write keys are the only different ones + # we want to carry this on for now return State(new_state) def get_all(self) -> Dict[str, Any]: @@ -331,7 +354,9 @@ def merge(self, other: "State") -> "State": def subset(self, *keys: str, ignore_missing: bool = True) -> "State": """Returns a subset of the state, with only the given keys""" - return State({key: self[key] for key in keys if key in self or not ignore_missing}) + return State( + {key: self[key] for key in keys if key in self or not ignore_missing}, + ) def __getitem__(self, __k: str) -> Any: return self._state[__k] diff --git a/docs/reference/state.rst b/docs/reference/state.rst index 4e5c3e06..d999bd6f 100644 --- a/docs/reference/state.rst +++ b/docs/reference/state.rst @@ -1,6 +1,6 @@ -================= +===== State -================= +===== Use the state API to manipulate the state of the application. diff --git a/tests/core/test_action.py b/tests/core/test_action.py index b623b598..70a49a63 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -1,5 +1,5 @@ import asyncio -from typing import AsyncGenerator, Generator, Optional, Tuple, cast +from typing import Any, AsyncGenerator, Dict, Generator, Optional, Tuple, cast import pytest @@ -719,3 +719,81 @@ async def callback(r: Optional[dict], s: State, e: Exception): ((result, state, error),) = called assert state["foo"] == "bar" assert result is None + + +def test_action_subclass_validate_defaults_fails_incorrect_writes(): + """Tests that the initialization of subclass hook validates as intended""" + + class IncorrectReads(Action): + @property + def reads(self) -> list[str]: + return ["foo"] + + @property + def writes(self) -> list[str]: + return ["bar"] + + @property + def default_reads(self) -> Dict[str, Any]: + return {"qux": None} + + def run(self, state: State) -> dict: + raise ValueError("This should never be called") + + def update(self, result: dict, state: State) -> State: + raise ValueError("This should never be called") + + with pytest.raises(ValueError): + IncorrectReads() + + +def test_action_subclass_validate_defaults_fails_incorrect_reads(): + """Tests that the initialiation of subclass hook validates as intended""" + + class IncorrectWrites(Action): + @property + def reads(self) -> list[str]: + return ["foo"] + + @property + def writes(self) -> list[str]: + return ["bar"] + + @property + def default_writes(self) -> Dict[str, Any]: + return {"qux": None} + + def run(self, state: State) -> dict: + raise ValueError("This should never be called") + + def update(self, result: dict, state: State) -> State: + raise ValueError("This should never be called") + + with pytest.raises(ValueError): + IncorrectWrites() + + +def test_action_subclass_validate_defaults_fails_reads_in_default_writes(): + """Tests that the initialiation of subclass hook validates as intended""" + + class IncorrectWrites(Action): + @property + def reads(self) -> list[str]: + return ["foo"] + + @property + def writes(self) -> list[str]: + return ["bar"] + + @property + def default_writes(self) -> Dict[str, Any]: + return {"foo": None} + + def run(self, state: State) -> dict: + raise ValueError("This should never be called") + + def update(self, result: dict, state: State) -> State: + raise ValueError("This should never be called") + + with pytest.raises(ValueError): + IncorrectWrites() diff --git a/tests/core/test_application.py b/tests/core/test_application.py index d6c4257a..bf77a7b0 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -31,6 +31,8 @@ _arun_multi_step_streaming_action, _arun_single_step_action, _arun_single_step_streaming_action, + _pre_apply_read_defaults, + _pre_apply_write_defaults, _run_function, _run_multi_step_streaming_action, _run_reducer, @@ -60,6 +62,8 @@ def __init__( fn: Callable[..., dict], update_fn: Callable[[dict, State], State], inputs: list[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, ): super(PassedInAction, self).__init__() self._reads = reads @@ -67,6 +71,8 @@ def __init__( self._fn = fn self._update_fn = update_fn self._inputs = inputs + self._default_reads = default_reads if default_reads is not None else {} + self._default_writes = default_writes if default_writes is not None else {} def run(self, state: State, **run_kwargs) -> dict: return self._fn(state, **run_kwargs) @@ -75,6 +81,14 @@ def run(self, state: State, **run_kwargs) -> dict: def inputs(self) -> list[str]: return self._inputs + @property + def default_reads(self) -> Dict[str, Any]: + return self._default_reads + + @property + def default_writes(self) -> Dict[str, Any]: + return self._default_writes + def update(self, result: dict, state: State) -> State: return self._update_fn(result, state) @@ -95,8 +109,18 @@ def __init__( fn: Callable[..., Awaitable[dict]], update_fn: Callable[[dict, State], State], inputs: list[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, ): - super().__init__(reads=reads, writes=writes, fn=fn, update_fn=update_fn, inputs=inputs) # type: ignore + super().__init__( + reads=reads, + writes=writes, + fn=fn, + update_fn=update_fn, + inputs=inputs, + default_reads=default_reads, + default_writes=default_writes, + ) # type: ignore async def run(self, state: State, **run_kwargs) -> dict: return await self._fn(state, **run_kwargs) @@ -105,7 +129,7 @@ async def run(self, state: State, **run_kwargs) -> dict: base_counter_action = PassedInAction( reads=["count"], writes=["count"], - fn=lambda state: {"count": state.get("count", 0) + 1}, + fn=lambda state: {"count": state["count"] + 1}, update_fn=lambda result, state: state.update(**result), inputs=[], ) @@ -113,13 +137,21 @@ async def run(self, state: State, **run_kwargs) -> dict: base_counter_action_with_inputs = PassedInAction( reads=["count"], writes=["count"], - fn=lambda state, additional_increment: { - "count": state.get("count", 0) + 1 + additional_increment - }, + fn=lambda state, additional_increment: {"count": state["count"] + 1 + additional_increment}, update_fn=lambda result, state: state.update(**result), inputs=["additional_increment"], ) +base_counter_action_with_defaults = PassedInAction( + reads=["count"], + writes=["count", "error"], + fn=lambda state: {"count": state["count"] + 1}, + update_fn=lambda result, state: state.update(**result), + inputs=[], + default_reads={"count": 0}, + default_writes={"error": None}, +) + class ActionTracker(PreRunStepHook, PostRunStepHook): def __init__(self): @@ -172,7 +204,7 @@ async def post_run_step(self, *, action: Action, **future_kwargs): async def _counter_update_async(state: State, additional_increment: int = 0) -> dict: await asyncio.sleep(0.0001) # just so we can make this *truly* async # does not matter, but more accurately simulates an async function - return {"count": state.get("count", 0) + 1 + additional_increment} + return {"count": state["count"] + 1 + additional_increment} base_counter_action_async = PassedInActionAsync( @@ -193,6 +225,16 @@ async def _counter_update_async(state: State, additional_increment: int = 0) -> inputs=["additional_increment"], ) +base_counter_action_async_with_defaults = PassedInActionAsync( + reads=["count"], + writes=["count", "error"], + fn=_counter_update_async, + update_fn=lambda result, state: state.update(**result), + inputs=[], + default_reads={"count": 0}, + default_writes={"error": None}, +) + class BrokenStepException(Exception): pass @@ -236,18 +278,46 @@ async def incorrect(x): ) +def test__pre_read_apply_defaults(): + state = State({"in_state": 0}) + defaults = {"in_state": 1, "not_in_state": 2} + result = _pre_apply_read_defaults(state, defaults) + assert result["in_state"] == 0 + assert result["not_in_state"] == 2 + + +def test__pre_write_apply_defaults(): + # Write defaults should always be applied, + # they'll be overwritten by the reducer + state = State({"in_state": 0, "to_be_overwritten": 0}) + defaults = {"in_state": 1, "not_in_state": 2, "to_be_overwritten": 3} + result = _pre_apply_write_defaults(state, defaults) + assert result["in_state"] == 1 + assert result["not_in_state"] == 2 + assert result["to_be_overwritten"] == 3 + + def test__run_function(): """Tests that we can run a function""" action = base_counter_action - state = State({}) + state = State({"count": 0}) result = _run_function(action, state, inputs={}, name=action.name) assert result == {"count": 1} +def test__run_function_defaults(): + action = base_counter_action_with_defaults + state = State({}) + result = _run_function(action, state, inputs={}, name=action.name) + assert result == { + "count": 1 + } # default read is applied, write is not as it is a reducer capability, and is not part of the result + + def test__run_function_with_inputs(): """Tests that we can run a function""" action = base_counter_action_with_inputs - state = State({}) + state = State({"count": 0}) result = _run_function(action, state, inputs={"additional_increment": 1}, name=action.name) assert result == {"count": 2} @@ -255,7 +325,7 @@ def test__run_function_with_inputs(): def test__run_function_cant_run_async(): """Tests that we can't run an async function""" action = base_counter_action_async - state = State({}) + state = State({"count": 0}) with pytest.raises(ValueError, match="async"): _run_function(action, state, inputs={}, name=action.name) @@ -263,11 +333,19 @@ def test__run_function_cant_run_async(): def test__run_function_incorrect_result_type(): """Tests that we can run an async function""" action = base_action_incorrect_result_type - state = State({}) + state = State({"count": 0}) with pytest.raises(ValueError, match="returned a non-dict"): _run_function(action, state, inputs={}, name=action.name) +def test__run_reducer_applies_defaults(): + """Tests that we can run a reducer and it behaves as expected""" + reducer = base_counter_action_with_defaults + state = State({"count": 0}) + state = _run_reducer(reducer, state, {"count": 1}, "reducer") + assert state.get_all() == {"count": 1, "error": None} + + def test__run_reducer_modifies_state(): """Tests that we can run a reducer and it behaves as expected""" reducer = PassedInAction( @@ -299,6 +377,14 @@ def test__run_reducer_deletes_state(): async def test__arun_function(): """Tests that we can run an async function""" action = base_counter_action_async + state = State({"count": 0}) + result = await _arun_function(action, state, inputs={}, name=action.name) + assert result == {"count": 1} + + +async def test__arun_function_with_defaults(): + """Tests that we can run an async function""" + action = base_counter_action_async_with_defaults state = State({}) result = await _arun_function(action, state, inputs={}, name=action.name) assert result == {"count": 1} @@ -315,7 +401,7 @@ async def test__arun_function_incorrect_result_type(): async def test__arun_function_with_inputs(): """Tests that we can run an async function""" action = base_counter_action_with_inputs_async - state = State({}) + state = State({"count": 0}) result = await _arun_function( action, state, inputs={"additional_increment": 1}, name=action.name ) @@ -447,6 +533,9 @@ def writes(self) -> list[str]: class SingleStepCounter(SingleStepAction): + def __init__(self): + super(SingleStepCounter, self).__init__() + def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: result = {"count": state["count"] + 1 + sum([0] + list(run_kwargs.values()))} return result, state.update(**result).append(tracker=result["count"]) @@ -466,6 +555,23 @@ def inputs(self) -> list[str]: return ["additional_increment"] +class SingleStepCounterWithDefaults(SingleStepCounter): + def __init__(self): + super(SingleStepCounterWithDefaults, self).__init__() + + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepActionIncorrectResultType(SingleStepAction): def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return "not a dict", state @@ -498,6 +604,20 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepCounterWithDefaultsAsync(SingleStepCounterAsync): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepCounterWithInputsAsync(SingleStepCounterAsync): @property def inputs(self) -> list[str]: @@ -527,6 +647,20 @@ def update(self, result: dict, state: State) -> State: return state.update(**result).append(tracker=result["count"]) +class StreamingCounterWithDefaults(StreamingCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class AsyncStreamingCounter(AsyncStreamingAction): async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: if "steps_per_count" in run_kwargs: @@ -552,6 +686,20 @@ def update(self, result: dict, state: State) -> State: return state.update(**result).append(tracker=result["count"]) +class StreamingCounterWithDefaultsAsync(AsyncStreamingCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepStreamingCounter(SingleStepStreamingAction): def stream_run_and_update( self, state: State, **run_kwargs @@ -571,6 +719,20 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepStreamingCounterWithDefaults(SingleStepStreamingCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepStreamingCounterAsync(SingleStepStreamingAction): async def stream_run_and_update( self, state: State, **run_kwargs @@ -592,6 +754,20 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepStreamingCounterWithDefaultsAsync(SingleStepStreamingCounterAsync): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class StreamingActionIncorrectResultType(StreamingAction): def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]: yield {} @@ -662,12 +838,20 @@ def writes(self) -> list[str]: base_single_step_counter_async = SingleStepCounterAsync() base_single_step_counter_with_inputs = SingleStepCounterWithInputs() base_single_step_counter_with_inputs_async = SingleStepCounterWithInputsAsync() +base_single_step_counter_with_defaults = SingleStepCounterWithDefaults() +base_single_step_counter_with_defaults_async = SingleStepCounterWithDefaultsAsync() base_streaming_counter = StreamingCounter() +base_streaming_counter_with_defaults = StreamingCounterWithDefaults() base_streaming_single_step_counter = SingleStepStreamingCounter() +base_streaming_single_step_counter_with_defaults = SingleStepStreamingCounterWithDefaults() base_streaming_counter_async = AsyncStreamingCounter() +base_streaming_counter_with_defaults_async = StreamingCounterWithDefaultsAsync() base_streaming_single_step_counter_async = SingleStepStreamingCounterAsync() +base_streaming_single_step_counter_with_defaults_async = ( + SingleStepStreamingCounterWithDefaultsAsync() +) base_single_step_action_incorrect_result_type = SingleStepActionIncorrectResultType() base_single_step_action_incorrect_result_type_async = SingleStepActionIncorrectResultTypeAsync() @@ -709,6 +893,18 @@ def test__run_single_step_action_with_inputs(): assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} +def test__run_single_step_action_with_defaults(): + action = base_single_step_counter_with_defaults.with_name("counter") + state = State({}) + result, state = _run_single_step_action(action, state, {}) + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + async def test__arun_single_step_action(): action = base_single_step_counter_async.with_name("counter") state = State({"count": 0, "tracker": []}) @@ -735,6 +931,18 @@ async def test__arun_single_step_action_with_inputs(): assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} +async def test__arun_single_step_action_with_defaults(): + action = base_single_step_counter_with_defaults_async.with_name("counter") + state = State({}) + result, state = await _arun_single_step_action(action, state, {}) + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + class SingleStepActionWithDeletion(SingleStepAction): def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return {}, state.wipe(delete=["to_delete"]) @@ -770,7 +978,26 @@ def test__run_multistep_streaming_action(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} -async def test__run_multistep_streaming_action_async(): +def test__run_multistep_streaming_action_default(): + action = base_streaming_counter_with_defaults.with_name("counter") + state = State({}) + generator = _run_multi_step_streaming_action(action, state, inputs={}) + last_result = -1 + result = None + for result, state in generator: + if last_result < 1: + # Otherwise you hit floating poit comparison problems + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + +async def test__arun_multistep_streaming_action(): action = base_streaming_counter_async.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_multi_step_streaming_action(action, state, inputs={}) @@ -785,6 +1012,25 @@ async def test__run_multistep_streaming_action_async(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} +async def test__arun_multistep_streaming_action_with_defaults(): + action = base_streaming_counter_with_defaults_async.with_name("counter") + state = State({}) + generator = _arun_multi_step_streaming_action(action, state, inputs={}) + last_result = -1 + result = None + async for result, state in generator: + if last_result < 1: + # Otherwise you hit floating poit comparison problems + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + def test__run_streaming_action_incorrect_result_type(): action = StreamingActionIncorrectResultType() state = State() @@ -793,7 +1039,7 @@ def test__run_streaming_action_incorrect_result_type(): collections.deque(gen, maxlen=0) # exhaust the generator -async def test__run_streaming_action_incorrect_result_type_async(): +async def test__arun_streaming_action_incorrect_result_type(): action = StreamingActionIncorrectResultTypeAsync() state = State() with pytest.raises(ValueError, match="returned a non-dict"): @@ -810,7 +1056,7 @@ def test__run_single_step_streaming_action_incorrect_result_type(): collections.deque(gen, maxlen=0) # exhaust the generator -async def test__run_single_step_streaming_action_incorrect_result_type_async(): +async def test__arun_single_step_streaming_action_incorrect_result_type(): action = StreamingSingleStepActionIncorrectResultTypeAsync() state = State() with pytest.raises(ValueError, match="returned a non-dict"): @@ -834,7 +1080,27 @@ def test__run_single_step_streaming_action(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} -async def test__run_single_step_streaming_action_async(): +def test__run_single_step_streaming_with_defaults(): + action = base_streaming_single_step_counter_with_defaults.with_name("counter") + state = State() + generator = _run_single_step_streaming_action(action, state, inputs={}) + last_result = -1 + result, state = None, None + for result, state in generator: + if last_result < 1: + # Otherwise you hit comparison issues + # This is because we get to the last one, which is the final result + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + +async def test__arun_single_step_streaming_action(): async_action = base_streaming_single_step_counter_async.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_single_step_streaming_action(async_action, state, inputs={}) @@ -850,6 +1116,26 @@ async def test__run_single_step_streaming_action_async(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} +async def test__arun_single_step_streaming_action_with_defaults(): + async_action = base_streaming_single_step_counter_with_defaults_async.with_name("counter") + state = State({}) + generator = _arun_single_step_streaming_action(async_action, state, inputs={}) + last_result = -1 + result, state = None, None + async for result, state in generator: + if last_result < 1: + # Otherwise you hit comparison issues + # This is because we get to the last one, which is the final result + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion): async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return {}, state.wipe(delete=["to_delete"]) @@ -866,7 +1152,7 @@ def test_app_step(): """Tests that we can run a step in an app""" counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -945,7 +1231,7 @@ def test_app_step_done(): """Tests that when we cannot run a step, we return None""" counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -963,7 +1249,7 @@ async def test_app_astep(): """Tests that we can run an async step in an app""" counter_action = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_async", partition_key="test", uid="test-123", @@ -1022,7 +1308,7 @@ async def test_app_astep_broken(caplog): """Tests that we can run a step in an app""" broken_action = base_broken_action_async.with_name("broken_action_unique_name") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="broken_action_unique_name", partition_key="test", uid="test-123", @@ -1042,7 +1328,7 @@ async def test_app_astep_done(): """Tests that when we cannot run a step, we return None""" counter_action = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_async", partition_key="test", uid="test-123", @@ -1060,7 +1346,7 @@ async def test_app_astep_done(): def test_app_many_steps(): counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1080,7 +1366,7 @@ def test_app_many_steps(): async def test_app_many_a_steps(): counter_action = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_async", partition_key="test", uid="test-123", @@ -1101,7 +1387,7 @@ def test_iterate(): result_action = Result("count").with_name("result") counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1141,7 +1427,7 @@ def test_iterate_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1170,7 +1456,7 @@ async def test_aiterate(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1202,7 +1488,7 @@ async def test_aiterate_halt_before(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1232,7 +1518,7 @@ async def test_app_aiterate_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1257,7 +1543,7 @@ def test_run(): result_action = Result("count").with_name("result") counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1279,7 +1565,7 @@ def test_run_halt_before(): result_action = Result("count").with_name("result") counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1302,7 +1588,7 @@ def test_run_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1326,7 +1612,7 @@ def test_run_with_inputs_multiple_actions(): counter_action1 = base_counter_action_with_inputs.with_name("counter1") counter_action2 = base_counter_action_with_inputs.with_name("counter2") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter1", partition_key="test", uid="test-123", @@ -1351,7 +1637,7 @@ async def test_arun(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1373,7 +1659,7 @@ async def test_arun_halt_before(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1396,7 +1682,7 @@ async def test_arun_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1421,7 +1707,7 @@ async def test_arun_with_inputs_multiple_actions(): counter_action1 = base_counter_action_with_inputs_async.with_name("counter1") counter_action2 = base_counter_action_with_inputs_async.with_name("counter2") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter1", partition_key="test", uid="test-123", @@ -1449,7 +1735,7 @@ async def test_app_a_run_async_and_sync(): counter_action_sync = base_counter_action_async.with_name("counter_sync") counter_action_async = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_sync", partition_key="test", uid="test-123", @@ -1878,7 +2164,7 @@ async def test_astream_result_halt_before(): def test_app_set_state(): counter_action = base_counter_action.with_name("counter") app = Application( - state=State(), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1901,7 +2187,7 @@ def test_app_get_next_step(): counter_action_2 = base_counter_action.with_name("counter_2") counter_action_3 = base_counter_action.with_name("counter_3") app = Application( - state=State(), + state=State({"count": 0}), entrypoint="counter_1", partition_key="test", uid="test-123", @@ -1988,7 +2274,7 @@ def test_application_run_step_hooks_sync(): counter_action = base_counter_action.with_name("counter") result_action = Result("count").with_name("result") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(tracker), partition_key="test", @@ -2035,7 +2321,7 @@ async def test_application_run_step_hooks_async(): counter_action = base_counter_action.with_name("counter") result_action = Result("count").with_name("result") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(tracker), partition_key="test", @@ -2079,7 +2365,7 @@ async def test_application_run_step_runs_hooks(): counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(*hooks), partition_key="test", @@ -2147,7 +2433,7 @@ def post_application_create(self, **kwargs): counter_action = base_counter_action.with_name("counter") result_action = Result("count").with_name("result") Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(tracker), partition_key="test", diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 4d15730f..9f94d89b 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -3,11 +3,12 @@ see failures in these tests, you should make a unit test, demonstrate the failure there, then fix both in that test and the end-to-end test.""" from io import StringIO -from typing import Any, Tuple +from typing import Any, AsyncGenerator, Generator, Tuple from unittest.mock import patch from burr.core import Action, ApplicationBuilder, State, action -from burr.core.action import Input, Result, expr +from burr.core.action import Input, Result, expr, streaming_action +from burr.core.graph import GraphBuilder from burr.lifecycle import base @@ -89,3 +90,105 @@ def echo(state: State) -> Tuple[dict, State]: format="png", ) assert result["response"] == prompt + + +def test_action_end_to_end_streaming_with_defaults(): + @streaming_action( + reads=["count"], + writes=["done", "error"], + default_reads={"count": 10}, + default_writes={"done": False, "error": None}, + ) + def echo( + state: State, should_error: bool, letter_to_repeat: str + ) -> Generator[Tuple[dict, State], None, None]: + for i in range(state["count"]): + yield {"letter_to_repeat": letter_to_repeat}, None + if should_error: + yield {"error": "Error"}, state.update(error="Error") + else: + yield {"done": True}, state.update(done=True) + + graph = ( + GraphBuilder() + .with_actions( + echo_success=echo.bind(should_error=False), + echo_failure=echo.bind(should_error=True), + ) + .with_transitions( + ("echo_success", "echo_failure"), + ("echo_failure", "echo_success"), + ) + .build() + ) + app = ApplicationBuilder().with_graph(graph).with_entrypoint("echo_success").build() + action_completed, streaming_container = app.stream_result( + halt_after=["echo_success"], inputs={"letter_to_repeat": "a"} + ) + for item in streaming_container: + assert item == {"letter_to_repeat": "a"} + result, state = streaming_container.get() + assert result == {"done": True} + assert state["done"] is True + assert state["error"] is None # default + + action_completed, streaming_container = app.stream_result( + halt_after=["echo_failure"], inputs={"letter_to_repeat": "a"} + ) + for item in streaming_container: + assert item == {"letter_to_repeat": "a"} + result, state = streaming_container.get() + assert result == {"error": "Error"} + assert state["done"] is False + assert state["error"] == "Error" + + +async def test_action_end_to_end_streaming_with_defaults_async(): + @streaming_action( + reads=["count"], + writes=["done", "error"], + default_reads={"count": 10}, + default_writes={"done": False, "error": None}, + ) + async def echo( + state: State, should_error: bool, letter_to_repeat: str + ) -> AsyncGenerator[Tuple[dict, State], None]: + for i in range(state["count"]): + yield {"letter_to_repeat": letter_to_repeat}, None + if should_error: + yield {"error": "Error"}, state.update(error="Error") + else: + yield {"done": True}, state.update(done=True) + + graph = ( + GraphBuilder() + .with_actions( + echo_success=echo.bind(should_error=False), + echo_failure=echo.bind(should_error=True), + ) + .with_transitions( + ("echo_success", "echo_failure"), + ("echo_failure", "echo_success"), + ) + .build() + ) + app = ApplicationBuilder().with_graph(graph).with_entrypoint("echo_success").build() + action_completed, streaming_container = await app.astream_result( + halt_after=["echo_success"], inputs={"letter_to_repeat": "a"} + ) + async for item in streaming_container: + assert item == {"letter_to_repeat": "a"} + result, state = await streaming_container.get() + assert result == {"done": True} + assert state["done"] is True + assert state["error"] is None # default + + action_completed, streaming_container = await app.astream_result( + halt_after=["echo_failure"], inputs={"letter_to_repeat": "a"} + ) + async for item in streaming_container: + assert item == {"letter_to_repeat": "a"} + result, state = await streaming_container.get() + assert result == {"error": "Error"} + assert state["done"] is False + assert state["error"] == "Error"