MLflow
What is MLflow?
MLflow is an open source framework created by Databricks to simplify model lifecycle management. It handles model tracking and deployment, and helps with interoperability between different ML tools.
You can find the MLflow documentation here, but for a hands-on (and significantly more exciting!) experience, check out the tutorial.
MLflow Tracking
One of the key features of MLflow is the ability to track metrics both during the training process and once the model is deployed. By integrating whylogs into the MLflow runtime, you can log data quality metrics as part of the model's pipeline:
- whylogs v0
- whylogs v1
# Note: MLflow integration is not yet supported by whylogs v1.
# whylogs v0.x can be installed via the following:
# pip install "whylogs<1.0"
import mlflow
import whylogs

whylogs.enable_mlflow()
import mlflow
import whylogs as why
After enabling the integration, whylogs can be used to profile the data flowing through the pipeline when running MLflow jobs:
- whylogs v0
- whylogs v1
# Note: MLflow integration is not yet supported by whylogs v1.
# whylogs v0.x can be installed via the following:
# pip install "whylogs<1.0"
with mlflow.start_run(run_name="whylogs demo"):
    # make the prediction and compute the error
    predicted_output = model.predict(batch)
    mae = mean_absolute_error(actuals, predicted_output)

    # standard MLflow tracking APIs
    mlflow.log_params(model_params)
    mlflow.log_metric("mae", mae)

    # profile the data with the whylogs API extension
    mlflow.whylogs.log_pandas(batch)
with mlflow.start_run(run_name="whylogs demo"):
    # make the prediction and compute the error
    predicted_output = model.predict(batch)
    mae = mean_absolute_error(actuals, predicted_output)

    # standard MLflow tracking APIs
    mlflow.log_params(model_params)
    mlflow.log_metric("mae", mae)

    # profile the data and store the profile as a run artifact
    profile_result = why.log(batch)
    profile_result.writer("mlflow").write()
Once whylogs profiles have been generated, they are stored by MLflow along with all the other artifacts from the run. They can be retrieved from the MLflow backend and explored further:
- whylogs v0
- whylogs v1
# whylogs v0.x can be installed via the following:
# pip install "whylogs<1.0"
import whylogs
from whylogs.viz import ProfileVisualizer

# get the profiles associated with the run
mlflow_profiles = whylogs.mlflow.get_experiment_profiles("experiment_1")

# visualize the profiles
viz = ProfileVisualizer()
viz.set_profiles(mlflow_profiles)
viz.plot_distribution("free sulfur dioxide", ts_format="%d-%b-%y %H:%M:%S")
# coming soon!
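In the meantime, profiles written through whylogs v1's MLflow writer are stored as regular run artifacts, so MLflow's standard artifact APIs can retrieve them. A minimal sketch, assuming a known run ID; the artifact path under which the profile lands is an assumption, so list the run's artifacts first:

import mlflow
from mlflow.tracking import MlflowClient

run_id = "YOUR_RUN_ID"  # the run that wrote the profile

# list the run's artifacts to find where the profile was stored
client = MlflowClient()
for artifact in client.list_artifacts(run_id):
    print(artifact.path)

# download a profile artifact locally for further inspection
# ("whylogs" as the artifact path is an assumption -- use a path printed above)
local_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path="whylogs")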
For additional information and in-depth examples, check out the whylogs documentation and examples.
MLflow Serving
In this section we will cover two methods to set up an integration between whylogs and MLflow serving, focusing on a Databricks-based environment.
Method #1
The best way to get whylogs profiling data out of an ML endpoint on Databricks' serverless infrastructure is to deploy a whylogs container separately from the ML model and have the model's prediction method send a request to the container every time it runs predictions. As long as the payload isn't too large, this should not significantly affect request latency, and it is the recommended approach for production use cases.
This integration is implemented by modifying the ML model using MLflow's PythonModel class definition:
import json
import pickle

import mlflow
import pandas as pd
import requests
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MinMaxScaler


class WhylogsModelWrapper(mlflow.pyfunc.PythonModel):
    def __init__(self):
        self.preprocessor = MinMaxScaler()
        self.model = RandomForestClassifier()

    def load_context(self, context):
        # replace the defaults with the fitted artifacts logged with the model
        with open(context.artifacts["preprocessor"], "rb") as f:
            self.preprocessor = pickle.load(f)
        with open(context.artifacts["estimator"], "rb") as f:
            self.model = pickle.load(f)

    def predict(self, context, data):
        transformed_data = self.preprocessor.transform(data)
        predictions = self.model.predict(transformed_data)

        # send the inputs and outputs to the whylogs container for profiling
        df = pd.DataFrame(data)
        df["output_model"] = predictions
        requests.post(
            url="<WHYLOGS_CONTAINER_ENDPOINT>/logs",
            data=json.dumps({
                "datasetId": "model-5",
                "timestamp": 0,
                "multiple": df.to_dict(orient="split"),
            }),
        )
        return predictions
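To deploy this wrapper, log it to MLflow with the pickled preprocessor and estimator attached as artifacts, so that load_context can find them. A minimal sketch, assuming fitted objects were pickled to local files beforehand (the file names and the fitted_scaler/fitted_model variables are placeholders):

import pickle

# assumed to exist: a fitted scaler and a fitted classifier
with open("preprocessor.pkl", "wb") as f:
    pickle.dump(fitted_scaler, f)
with open("estimator.pkl", "wb") as f:
    pickle.dump(fitted_model, f)

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=WhylogsModelWrapper(),
        artifacts={
            "preprocessor": "preprocessor.pkl",
            "estimator": "estimator.pkl",
        },
        pip_requirements=["mlflow", "scikit-learn", "pandas", "requests"],
    )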
Method #2
Another way to use Databricks' serverless inference endpoints is to customize the MLflow model using whylogs' rolling logger. This is not yet well suited for production use cases, as it blocks the application's main thread; but given Databricks' design requirements it is good enough for a POC, and it involves less wiring on the user's side.
Define the customized model as:
import atexit
import pickle

import mlflow
import pandas as pd
import whylogs as why
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MinMaxScaler


class WhylogsModelWrapper(mlflow.pyfunc.PythonModel):
    def __init__(self):
        self.preprocessor = MinMaxScaler()
        self.model = RandomForestClassifier()
        self.logger = None

    def _start_logger(self):
        # rotate profiles every 4 minutes (see Known limitations below)
        self.logger = why.logger(
            mode="rolling", interval=4, when="M", base_name="message_profile"
        )
        self.logger.append_writer("whylabs")

        # flush any pending profiles if the process exits gracefully
        @atexit.register
        def close():
            if self.logger:
                self.logger.close()

    def load_context(self, context):
        # replace the defaults with the fitted artifacts logged with the model
        with open(context.artifacts["preprocessor"], "rb") as f:
            self.preprocessor = pickle.load(f)
        with open(context.artifacts["estimator"], "rb") as f:
            self.model = pickle.load(f)

    def predict(self, context, data):
        if not self.logger:
            self._start_logger()
        transformed_data = self.preprocessor.transform(data)
        predictions = self.model.predict(transformed_data)

        # profile the inputs and outputs with the rolling logger
        df = pd.DataFrame(data)
        df["output_model"] = predictions
        self.logger.log(df)
        return predictions
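The wrapper is logged the same way as in Method #1; the key difference is making sure whylogs itself is installed in the model's serving environment, so the rolling logger and the WhyLabs writer are available at inference time. A sketch, with the package list as an assumption:

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=WhylogsModelWrapper(),
        artifacts={
            "preprocessor": "preprocessor.pkl",
            "estimator": "estimator.pkl",
        },
        # whylogs must be present in the serving environment
        pip_requirements=["mlflow", "scikit-learn", "pandas", "whylogs"],
    )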
Define Environment Variables
To make writing to WhyLabs possible, whylogs needs three environment variables, which point to the correct organization, model (dataset), and API key. To set these up in a Databricks environment, you must:
- Store the environment variables as Databricks secrets
from databricks.sdk import WorkspaceClient
w = WorkspaceClient()
org_id_env = "WHYLABS_DEFAULT_ORG_ID"
dataset_id_env = "WHYLABS_DEFAULT_DATASET_ID"
api_key_env = "WHYLABS_API_KEY"
scope_name = "YOUR_SECRET_SCOPE"
w.secrets.create_scope(scope=scope_name)
w.secrets.put_secret(scope=scope_name, key=org_id_env, string_value="YOUR_ORG_ID")
w.secrets.put_secret(scope=scope_name, key=dataset_id_env, string_value="YOUR_DATASET_ID")
w.secrets.put_secret(scope=scope_name, key=api_key_env, string_value="YOUR_WHYLABS_API_KEY")
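Secret values cannot be read back through the SDK, but you can verify that the keys landed in the scope. A quick check, using the names defined above:

# list the secret keys stored in the scope (values are never returned)
for secret in w.secrets.list_secrets(scope=scope_name):
    print(secret.key)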
- Update the deployed model endpoint with the environment variables
import json

import requests

name = "YOUR_MODEL_ENDPOINT_NAME"
url = f"https://<YOUR_DATABRICKS_HOST>/api/2.0/serving-endpoints/{name}/config"
headers = {"Authorization": "Bearer <YOUR_TOKEN>", "Content-Type": "application/json"}

payload = {
    "served_models": [
        {
            "model_name": "YOUR_MODEL_NAME",
            "model_version": "YOUR_MODEL_VERSION",
            "workload_size": "Small",
            "scale_to_zero_enabled": True,
            "env_vars": [
                {
                    "env_var_name": "WHYLABS_DEFAULT_ORG_ID",
                    "secret_scope": "YOUR_SECRET_SCOPE",
                    "secret_key": "WHYLABS_DEFAULT_ORG_ID",
                },
                {
                    "env_var_name": "WHYLABS_DEFAULT_DATASET_ID",
                    "secret_scope": "YOUR_SECRET_SCOPE",
                    "secret_key": "WHYLABS_DEFAULT_DATASET_ID",
                },
                {
                    "env_var_name": "WHYLABS_API_KEY",
                    "secret_scope": "YOUR_SECRET_SCOPE",
                    "secret_key": "WHYLABS_API_KEY",
                },
            ],
        }
    ]
}

response = requests.put(url, headers=headers, data=json.dumps(payload))
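The config update is asynchronous, so it helps to check the response and then poll the endpoint until the new configuration is live. A sketch using the serving-endpoints GET API (the exact shape of the returned state is an assumption worth verifying against the Databricks docs):

response.raise_for_status()

# poll the endpoint's status until the config update completes
status = requests.get(
    f"https://<YOUR_DATABRICKS_HOST>/api/2.0/serving-endpoints/{name}",
    headers=headers,
)
print(status.json().get("state"))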
Known limitations
- According to Databricks' docs, if the endpoint receives no requests for 30 minutes, it is torn down
- If it has scaled up due to high traffic, Databricks checks every 5 minutes whether new requests have come in
- By rotating every 4 minutes, we try to ensure that no information arriving at any of the ML model's replicas is lost, but this still relies on how strictly the documented 5-minute rule applies in each and every case

If the container is shut down and the application stops gracefully, whylogs will still rotate any logs that have not yet been written to the platform, but there is no guarantee that this will always be the case.
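If your serving setup exposes a shutdown hook, closing the rolling logger explicitly is more reliable than depending on atexit alone. A minimal sketch, assuming the Method #2 wrapper instance is in scope (the hook name is hypothetical):

# call this from whatever lifecycle/shutdown event your serving framework exposes
def on_shutdown(wrapper: WhylogsModelWrapper):
    if wrapper.logger:
        # close() rotates and writes any profile that hasn't been flushed yet
        wrapper.logger.close()
        wrapper.logger = None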
Get in touch
In this documentation page, we showed how to integrate WhyLabs with your MLflow models, both at training and inference time, using whylogs profiles, the whylogs container, and its built-in WhyLabs writer. If you have questions or want to understand more about how you can use WhyLabs with your models, contact us anytime!