MLflow

What is MLflow?

MLflow is an open source framework created by Databricks to simplify model lifecycle management. It handles model tracking and deployment, and helps with interoperability between different ML tools.

You can find the MLflow documentation here, but for a hands-on (and significantly more exciting!) experience, check out the tutorial.

MLflow Tracking

One of the key features of MLflow is the ability to track metrics both during the training process and once the model is deployed. By integrating whylogs into the MLflow runtime, you can log data quality metrics as part of the model's pipeline:

# Note: MLflow integration is not yet supported by whylogs v1
# whylogs v0.x can be installed via the following
# pip install "whylogs<1.0"

import mlflow
import whylogs

whylogs.enable_mlflow()

After enabling the integration, whylogs can be used to profile the data flowing through the pipeline when running MLflow jobs:

# Note: MLflow integration is not yet supported by whylogs v1
# whylogs v0.x can be installed via the following
# pip install "whylogs<1.0"

from sklearn.metrics import mean_absolute_error

# `model`, `batch`, `actuals` and `model_params` come from your own training code
with mlflow.start_run(run_name="whylogs demo"):
    # make the prediction and compute error
    predicted_output = model.predict(batch)
    mae = mean_absolute_error(actuals, predicted_output)

    # standard MLflow tracking APIs
    mlflow.log_params(model_params)
    mlflow.log_metric("mae", mae)

    # profiling the data with whylogs API extension
    mlflow.whylogs.log_pandas(batch)

Once whylogs profiles have been generated, they are stored by MLflow along with all the other artifacts from the run. They can be retrieved from the MLflow backend and explored further:

# whylogs v0.x can be installed via the following
# pip install "whylogs<1.0"

import whylogs
from whylogs.viz import ProfileVisualizer

# get the profiles associated with the run
mlflow_profiles = whylogs.mlflow.get_experiment_profiles("experiment_1")

# visualize the profiles
viz = ProfileVisualizer()
viz.set_profiles(mlflow_profiles)
viz.plot_distribution("free sulfur dioxide", ts_format="%d-%b-%y %H:%M:%S")

whylogs profiles - Distribution Over Time
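
Because the profiles are stored as regular run artifacts, you can also browse and download them with the standard MLflow tracking client. Below is a minimal sketch; the run id and artifact path placeholders are assumptions, and the exact artifact layout under the run depends on the whylogs version.

from mlflow.tracking import MlflowClient

client = MlflowClient()
run_id = "YOUR_RUN_ID"  # hypothetical run id from the MLflow UI or a search_runs() call

# list everything logged for the run, including the whylogs profile artifacts
for artifact in client.list_artifacts(run_id):
    print(artifact.path)

# download one of the listed paths locally for further inspection
local_path = client.download_artifacts(run_id, "YOUR_ARTIFACT_PATH", "./run_artifacts")
print(local_path)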

For additional information and in-depth examples, check out the following:

MLflow Serving

In this section we will cover two methods for setting up an integration between whylogs and MLflow, focusing on a Databricks-based environment.

Method #1

The recommended way to capture whylogs profiles from an ML endpoint on Databricks' serverless infrastructure is to deploy a whylogs container separately from the ML model. The model's prediction method then sends a request to the container every time it runs predictions. As long as the payload isn't too large, this should not significantly affect request latency, which makes it the recommended approach for production use cases.

This integration is implemented by wrapping the ML model with MLflow's PythonModel class:

import json
import pickle

import mlflow
import pandas as pd
import requests
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MinMaxScaler


class WhylogsModelWrapper(mlflow.pyfunc.PythonModel):
    def __init__(self):
        self.preprocessor = MinMaxScaler()
        self.model = RandomForestClassifier()

    def load_context(self, context):
        # replace the placeholders with the fitted objects stored as run artifacts
        with open(context.artifacts["preprocessor"], "rb") as f:
            self.preprocessor = pickle.load(f)
        with open(context.artifacts["estimator"], "rb") as f:
            self.model = pickle.load(f)

    def predict(self, context, data):
        transformed_data = self.preprocessor.transform(data)
        predictions = self.model.predict(transformed_data)

        df = pd.DataFrame(data)
        df["output_model"] = predictions

        # send the inputs and predictions to the whylogs container for profiling
        requests.post(
            url="<WHYLOGS_CONTAINER_ENDPOINT>/logs",
            data=json.dumps({
                "datasetId": "model-5",
                "timestamp": 0,
                "multiple": df.to_dict(orient="split"),
            }),
        )

        return predictions
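
For completeness, here is a hedged sketch of how such a wrapper could be logged to MLflow so that load_context can find its artifacts. The local pickle paths, the artifact_path value, and the pip requirements list are assumptions for illustration; preprocessor and model stand in for the fitted objects from your own training code. The same pattern applies to the Method #2 wrapper below.

import pickle

import mlflow

# `preprocessor` and `model` are assumed to be the fitted objects from your training code
with open("preprocessor.pkl", "wb") as f:
    pickle.dump(preprocessor, f)
with open("model.pkl", "wb") as f:
    pickle.dump(model, f)

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=WhylogsModelWrapper(),
        artifacts={
            # keys must match what load_context reads from context.artifacts
            "preprocessor": "preprocessor.pkl",
            "estimator": "model.pkl",
        },
        # the serving environment needs everything the wrapper imports at predict time
        pip_requirements=["scikit-learn", "pandas", "requests"],
    )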

Method #2

Another way to use Databricks' serverless inference endpoints is to customize the MLflow model with whylogs' rolling logger. This is not yet well suited for production use cases, since it blocks the application's main thread, but given Databricks' design constraints it is good enough for a proof of concept and requires less wiring on the user's side.

Define the customized model as:

import atexit
import pickle

import mlflow
import pandas as pd
import whylogs as why
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MinMaxScaler


class WhylogsModelWrapper(mlflow.pyfunc.PythonModel):
    def __init__(self):
        self.preprocessor = MinMaxScaler()
        self.model = RandomForestClassifier()
        self.logger = None

    def _start_logger(self):
        # rolling logger that rotates the profile every 5 minutes
        self.logger = why.logger(
            mode="rolling", interval=5, when="M", base_name="message_profile"
        )
        self.logger.append_writer("whylabs")

        # make sure pending profiles are written when the process exits
        @atexit.register
        def close():
            if self.logger:
                self.logger.close()

    def load_context(self, context):
        # replace the placeholders with the fitted objects stored as run artifacts
        with open(context.artifacts["preprocessor"], "rb") as f:
            self.preprocessor = pickle.load(f)
        with open(context.artifacts["estimator"], "rb") as f:
            self.model = pickle.load(f)

    def predict(self, context, data):
        if not self.logger:
            self._start_logger()

        transformed_data = self.preprocessor.transform(data)
        predictions = self.model.predict(transformed_data)

        df = pd.DataFrame(data)
        df["output_model"] = predictions

        # profile the inputs and predictions with whylogs
        self.logger.log(df)

        return predictions
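
Before deploying, it can be useful to smoke test the wrapper locally. Below is a minimal sketch, assuming the model was logged with mlflow.pyfunc.log_model (with "preprocessor" and "estimator" artifacts) as in the Method #1 sketch, that YOUR_RUN_ID and the feature names are placeholders, and that the WhyLabs environment variables from the next section are already set locally so the rolling logger's writer can authenticate.

import mlflow
import pandas as pd

run_id = "YOUR_RUN_ID"  # the run that logged the wrapper
loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")

# a tiny batch shaped like your training data (hypothetical feature names)
sample_df = pd.DataFrame({"feature_1": [0.1, 0.2], "feature_2": [1.0, 2.0]})

print(loaded_model.predict(sample_df))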

Define Environment Variables

To write to WhyLabs, whylogs needs three environment variables that identify the organization, the model (dataset ID), and the API key. To set them in a Databricks environment, you must:

  1. Store the environment variables as Databricks secrets
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

org_id_env = "WHYLABS_DEFAULT_ORG_ID"
dataset_id_env = "WHYLABS_DEFAULT_DATASET_ID"
api_key_env = "WHYLABS_API_KEY"

scope_name = "YOUR_SECRET_SCOPE"

w.secrets.create_scope(scope=scope_name)
w.secrets.put_secret(scope=scope_name, key=org_id_env, string_value="YOUR_ORG_ID")
w.secrets.put_secret(scope=scope_name, key=dataset_id_env, string_value="YOUR_DATASET_ID")
w.secrets.put_secret(scope=scope_name, key=api_key_env, string_value="YOUR_WHYLABS_API_KEY")
  2. Update the deployed model endpoint with the environment variables
name = "YOUR_MODEL_ENDPOINT_NAME"
url = f"https://<YOUR_DATABRICKS_HOST>/api/2.0/serving-endpoints/{name}/config"

headers = {"Authorization": "Bearer <YOUR_TOKEN>", "Content-Type": "application/json"}
payload = {
"served_models": [
{
"model_name": "YOUR_MODEL_NAME",
"model_version": "YOUR_MODEL_VERSION",
"workload_size": "Small",
"scale_to_zero_enabled": True,
"env_vars": [
{
"env_var_name": "WHYLABS_DEFAULT_ORG_ID",
"secret_scope": "YOUR_SECRET_SCOPE",
"secret_key": "YOUR_SECRET_SCOPE",
},
{
"env_var_name": "WHYLABS_DEFAULT_DATASET_ID",
"secret_scope": "YOUR_SECRET_SCOPE",
"secret_key": "YOUR_SECRET_SCOPE",
},
{
"env_var_name": "WHYLABS_API_KEY",
"secret_scope": "YOUR_SECRET_SCOPE",
"secret_key": "YOUR_SECRET_SCOPE",
}
],
}
]
}
response = requests.put(url, headers=headers, data=json.dumps(payload))
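
Once the endpoint picks up the new configuration, you can exercise it end to end and confirm that profiles show up in WhyLabs. Below is a minimal sketch, assuming the standard Databricks serving invocation URL and hypothetical feature names.

import json

import requests

name = "YOUR_MODEL_ENDPOINT_NAME"
url = f"https://<YOUR_DATABRICKS_HOST>/serving-endpoints/{name}/invocations"
headers = {"Authorization": "Bearer <YOUR_TOKEN>", "Content-Type": "application/json"}

# a small scoring request in MLflow's pandas "split" format (hypothetical features)
payload = {
    "dataframe_split": {
        "columns": ["feature_1", "feature_2"],
        "data": [[0.1, 1.0], [0.2, 2.0]],
    }
}

response = requests.post(url, headers=headers, data=json.dumps(payload))
print(response.json())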

Known limitations

  • According to Databricks' documentation, if the endpoint receives no requests for 30 minutes, it is torn down
  • If it has scaled up due to high traffic, Databricks checks every 5 minutes whether new requests have arrived before scaling back down
  • By rotating the profiles every 4 minutes, we try to ensure that no information reaching any of the ML model's replicas is lost, but this still depends on how strictly the documented 5-minute rule applies in each case

If the container is shut down and the application stops gracefully, any profiles that have not yet been written to the platform are still rotated and written, but there is no guarantee that this will always be the case.

Get in touch

This page showed how to integrate WhyLabs with your MLflow models, both at training and at inference time, using whylogs profiles, the whylogs container, and its built-in WhyLabs writer. If you have questions or want to learn more about using WhyLabs with your models, contact us anytime!
