diff --git a/src/rez/tests/test_context.py b/src/rez/tests/test_context.py index dc1b4a254..225f08e46 100644 --- a/src/rez/tests/test_context.py +++ b/src/rez/tests/test_context.py @@ -5,16 +5,20 @@ """ test resolved contexts """ +import sys + from rez.tests.util import restore_os_environ, restore_sys_path, TempdirMixin, \ TestBase from rez.resolved_context import get_lock_request, PatchLock, ResolvedContext from rez.bundle_context import bundle_context from rez.bind import hello_world +from rez.solver import SolverCallbackReturn, SolverState from rez.utils.platform_ import platform_ from rez.utils.filesystem import is_subdirectory from rez.utils.formatting import PackageRequest from rez.version import Version import unittest +import unittest.mock import subprocess import platform import shutil @@ -252,6 +256,39 @@ def test_orderer_package_argument(self): resolved = [x.qualified_package_name for x in r.resolved_packages] self.assertEqual(resolved, ['python-2.7.0']) + def test_callback_1(self): + def solver_callback(solver_state: SolverState): + if solver_state.num_fails > 999: + solver_callback_return = SolverCallbackReturn.fail + abort_reason = 'Too many fails' + else: + solver_callback_return = SolverCallbackReturn.keep_going + abort_reason = 'No reason' + return solver_callback_return, abort_reason + + callback = ResolvedContext.Callback(max_fails=1, time_limit=2, callback=solver_callback, buf=sys.stdout) + solve_state = SolverState(3, 1, None) + callback_result = callback(solve_state) + assert callback_result[0] == SolverCallbackReturn.fail + + def test_callback_2(self): + def solver_callback(solver_state: SolverState): + if solver_state.num_fails > 999: + solver_callback_return = SolverCallbackReturn.fail + abort_reason = 'Too many fails' + else: + solver_callback_return = SolverCallbackReturn.keep_going + abort_reason = 'No reason' + return solver_callback_return, abort_reason + + callback = ResolvedContext.Callback(max_fails=999, time_limit=0, callback=solver_callback, buf=sys.stdout) + solve_state = SolverState(3, 1, None) + # Mock start time to be Jan 1, 1970. + with unittest.mock.patch('rez.resolved_context.ResolvedContext.Callback.start_time', + new_callable=unittest.mock.PropertyMock, return_value=0): + callback_result = callback(solve_state) + assert callback_result[0] == SolverCallbackReturn.abort + def test_get_lock_request_1(self): pkg = 'foo' version = Version('1.2.1')