diff --git a/patsy/design_info.py b/patsy/design_info.py index 2f98968..7cd74a0 100644 --- a/patsy/design_info.py +++ b/patsy/design_info.py @@ -685,7 +685,7 @@ def var_names(self, eval_env=0): else: return {} - def partial(self, columns, product=False): + def partial(self, columns, product=False, eval_env=0): """Returns a partial prediction array where only the variables in the dict ``columns`` are tranformed per the :class:`DesignInfo` transformations. The terms that are not influenced by ``columns`` @@ -703,6 +703,18 @@ def partial(self, columns, product=False): :returns: A numpy array of the partial design matrix. """ from .highlevel import dmatrix + from types import ModuleType + + if not eval_env: + from patsy.eval import EvalEnvironment + eval_env = EvalEnvironment.capture(eval_env, reference=1) + + # We need to get rid of the non-callable items from the eval_env + namespaces = [{key: value} for ns in eval_env._namespaces + for key, value in six.iteritems(ns) + if callable(value) or isinstance(value, ModuleType)] + eval_env._namespaces = namespaces + if product: columns = _column_product(columns) rows = None @@ -712,7 +724,7 @@ def partial(self, columns, product=False): rows = len(columns[col]) parts = [] for term, subterm in six.iteritems(self.term_codings): - term_vars = term.var_names() + term_vars = term.var_names(eval_env) present = True for term_var in term_vars: if term_var not in columns: @@ -1312,6 +1324,16 @@ def test_DesignInfo_partial(): assert_raises(ValueError, dm.design_info.partial, {'a': ['a', 'b'], 'b': [1, 2, 3]}) + def some_function(x): + return np.where(x > 2, 1, 2) + + dm = dmatrix('1 + some_function(c)') + x = np.array([[0, 2], + [0, 2], + [0, 1]]) + y = dm.design_info.partial({'c': np.array([1, 2, 3])}) + assert_allclose(x, y) + def _column_product(columns): from itertools import product