Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLVWrapper #1377

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 111 additions & 4 deletions pymc_marketing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@
msg = "This module requires mlflow. Install using `pip install mlflow`"
raise ImportError(msg)

import mlflow.artifacts
from mlflow.pyfunc.model import PythonModel
from mlflow.utils.autologging_utils import autologging_integration

from pymc_marketing.clv.models.basic import CLVModel
Expand Down Expand Up @@ -471,7 +473,7 @@
mlflow.log_metric(f"{metric}_{stat.replace('%', '')}", value)


class MMMWrapper(mlflow.pyfunc.PythonModel):
class MMMWrapper(PythonModel):
"""A class to prepare a PyMC Marketing Mix Model (MMM) for logging and registering in MLflow.

This class extends MLflow's PythonModel to handle prediction tasks using a PyMC-based MMM.
Expand Down Expand Up @@ -706,8 +708,47 @@
mlflow.register_model(model_uri, registered_model_name)


def load_mmm(
class CLVWrapper(PythonModel):
"""Wrapper class for logging with MLflow."""

def __init__(self, model: CLVModel, method: str):
self.model = model
self.method = method

Check warning on line 716 in pymc_marketing/mlflow.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mlflow.py#L715-L716

Added lines #L715 - L716 were not covered by tests

def predict(
self,
context: Any,
model_input,
params: dict[str, Any] | None = None,
) -> Any:
"""Perform predictions or sampling using the specified prediction method."""
return getattr(self.model, self.method)(model_input, **params)

Check warning on line 725 in pymc_marketing/mlflow.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mlflow.py#L725

Added line #L725 was not covered by tests


def log_clv(
model: CLVModel,
method: str,
artifact_path: str = "model",
registered_model_name: str | None = None,
) -> None:
"""Log a PyMC-Marketing CLV model as a native MLflow model for the current run."""
mlflow_clv = CLVWrapper(model=model, method=method)

Check warning on line 735 in pymc_marketing/mlflow.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mlflow.py#L735

Added line #L735 was not covered by tests

mlflow.pyfunc.log_model(

Check warning on line 737 in pymc_marketing/mlflow.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mlflow.py#L737

Added line #L737 was not covered by tests
artifact_path=artifact_path,
python_model=mlflow_clv,
)

run_id = mlflow.active_run().info.run_id
model_uri = f"runs:/{run_id}/{artifact_path}"

Check warning on line 743 in pymc_marketing/mlflow.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mlflow.py#L742-L743

Added lines #L742 - L743 were not covered by tests

if registered_model_name:
mlflow.register_model(model_uri, registered_model_name)

Check warning on line 746 in pymc_marketing/mlflow.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mlflow.py#L745-L746

Added lines #L745 - L746 were not covered by tests


def _load_model(
run_id: str,
cls,
full_model: bool = False,
keep_idata: bool = False,
artifact_path: str = "model",
Expand Down Expand Up @@ -737,7 +778,6 @@
model : mlflow.pyfunc.PyFuncModel | MMM
The loaded MLflow PyFuncModel or MMM model.


Examples
--------
.. code-block:: python
Expand All @@ -759,7 +799,7 @@
run_id=run_id, artifact_path="idata.nc", dst_path=dst_path
)

model = MMM.load(idata_path)
model = cls.load(idata_path)

Check warning on line 802 in pymc_marketing/mlflow.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mlflow.py#L802

Added line #L802 was not covered by tests

if not keep_idata:
_force_load_idata_groups(model.idata)
Expand All @@ -777,6 +817,73 @@
return model


def load_mmm(
run_id: str,
full_model: bool = False,
keep_idata: bool = False,
artifact_path: str = "model",
dst_path: str | None = None,
) -> mlflow.pyfunc.PyFuncModel | MMM:
"""
Load a PyMC-Marketing MMM model from MLflow.

Can either load the full model including the InferenceData, or just the lighter PyFuncModel version.

Parameters
----------
run_id : str
The MLflow run ID from which to load the model.
full_model : bool, default=True
If True, load the full MMM model including the InferenceData.
keep_idata : bool, default=False
If True, keep the downloaded InferenceData saved locally.
artifact_path : str, default="model"
The artifact path within the run where the model is stored.
dst_path : str | None, default=None
The local destination path where the InferenceData will be downloaded.
If None, defaults to "idata_{run_id}" to avoid conflicts when loading multiple models.

Returns
-------
model : mlflow.pyfunc.PyFuncModel | MMM
The loaded MLflow PyFuncModel or MMM model.


Examples
--------
.. code-block:: python

# Load model using run_id
model = load_mmm(run_id="your_run_id", full_model=True, keep_idata=True)
"""
return _load_model(

Check warning on line 859 in pymc_marketing/mlflow.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mlflow.py#L859

Added line #L859 was not covered by tests
run_id=run_id,
full_model=full_model,
keep_idata=keep_idata,
artifact_path=artifact_path,
dst_path=dst_path,
cls=MMM,
)


def load_clv(
run_id: str,
full_model: bool = False,
keep_idata: bool = False,
artifact_path: str = "model",
dst_path: str | None = None,
) -> mlflow.pyfunc.PyFuncModel | CLVModel:
"""Load a PyMC-Marketing CLV model from MLflow."""
return _load_model(

Check warning on line 877 in pymc_marketing/mlflow.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mlflow.py#L877

Added line #L877 was not covered by tests
run_id=run_id,
full_model=full_model,
keep_idata=keep_idata,
artifact_path=artifact_path,
dst_path=dst_path,
cls=CLVModel,
)


@autologging_integration(FLAVOR_NAME)
def autolog(
log_sampler_info: bool = True,
Expand Down
Loading