---
title: "How to A/B test Tensorflow models using Sagemaker Endpoints"
description: "How to deploy multiple model versions as one Sagemaker Endpoint"
author: "Bartosz Mikulski"
author_bio: "Principal AI Engineer & MLOps Architect. I bridge the gap between \"it works in a notebook\" and \"it works for 200 million users.\""
author_url: https://mikulskibartosz.name
author_linkedin: https://www.linkedin.com/in/mikulskibartosz/
author_github: https://github.com/mikulskibartosz
canonical_url: https://mikulskibartosz.name/how-to-ab-test-tensorflow-models-using-sagemaker-endpoints
---

How can we run multiple versions of a Tensorflow model in production at the same time? There are many possible solutions. The simplest option is to deploy those versions as two separate models in Tensorflow Serving and switch between them in the application code. However, that quickly becomes difficult to maintain when we want to do a canary release or A/B test more than two models.

Thankfully, Sagemaker Endpoints simplify A/B testing of machine learning models. We can achieve the desired result using a few lines of code.

In this article, I'll show you how to define multiple ML models, configure them as Sagemaker Endpoint variants, deploy the endpoint, and capture the model results from all deployed versions.

## Dependencies and Imports

To run the code, we need two dependencies:

```
boto3==1.14.12
sagemaker==2.5.3
```

If you prefer to run the deployment script as a step in the AWS Code Pipeline, [take a look at this article](https://mikulskibartosz.name/deploy-tensorflow-using-sagemaker-endpoints).

I assume that the code runs in the environment in which **AWS API key and secret have been provided using environment variables**.

In the deployment script, we have to import Sagemaker and create the session.

```python
import sagemaker
from sagemaker.session import production_variant

from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.model_monitor import DataCaptureConfig

sagemaker_session = sagemaker.Session()

role = 'ARN of the role that has access to Sagemaker and the deployment bucket in S3'
```

## Creating Model Versions

Before we continue, we have to archive the saved Tensorflow models as tar.gz files and store them in an S3 bucket. In the next step, we create two models in AWS Sagemaker using those tar.gz files.
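Packaging can be scripted. Below is a minimal sketch using only the standard library; `package_saved_model` is my own helper name, and the single numeric version directory inside the archive follows the Tensorflow Serving layout. Using '1' rather than '0' as the version id avoids the problem described at the end of this article.

```python
import tarfile


def package_saved_model(saved_model_dir, archive_path, version='1'):
    """Archive an exported SavedModel directory as a model tar.gz.

    The files are placed under a numeric version directory, as expected
    by Tensorflow Serving (use '1', not '0' -- see the last section).
    """
    with tarfile.open(archive_path, 'w:gz') as archive:
        # arcname makes the version directory the top-level entry.
        archive.add(saved_model_dir, arcname=version)
```

After packaging, upload the archive to the S3 bucket referenced by `model_data`.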

```python
model_version_A = TensorFlowModel(
    name='model-version-a',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-a.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

sagemaker_session.create_model(
    'model-version-a',
    role,
    model_version_A.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)

model_version_B = TensorFlowModel(
    name='model-version-b',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-b.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

sagemaker_session.create_model(
    'model-version-b',
    role,
    model_version_B.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)
```

The `entry_point` is the file containing `input_handler` and `output_handler` functions used to convert the HTTP request to input compatible with Tensorflow Serving and convert the response back to the format expected by the client application. The `source_dir` is the directory where we stored the `inference.py` script. We can also include the `requirements.txt` file in the `source_dir` directory to install additional dependencies.
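For reference, a minimal `inference.py` could look like the sketch below. The JSON passthrough logic is illustrative only; adapt the handlers to your model's input and output formats.

```python
import json


def input_handler(data, context):
    """Convert the incoming HTTP request into a Tensorflow Serving request.

    `data` is a stream containing the request payload; `context` describes
    the request (content type, accept header, etc.).
    """
    if context.request_content_type == 'application/json':
        payload = data.read().decode('utf-8')
        # Tensorflow Serving's REST API expects {"instances": [...]}
        return json.dumps({'instances': json.loads(payload)})
    raise ValueError(
        f'Unsupported content type: {context.request_content_type}')


def output_handler(data, context):
    """Convert the Tensorflow Serving response into the client response."""
    if data.status_code != 200:
        raise ValueError(data.content.decode('utf-8'))
    # Pass the prediction through, using the content type the client accepts.
    return data.content, context.accept_header
```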

## Creating Variants

Now, we define the endpoint variants by specifying the model names and the percentage of traffic routed to each variant (the `initial_weight` parameter):

```python
variantA = production_variant(
    model_name='model-version-a',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantA",
    initial_weight=50,
)

variantB = production_variant(
    model_name='model-version-b',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantB",
    initial_weight=50,
)
```

## Configuring Data Capture

We want to log the requests and responses to verify which version of the model performs better. A Sagemaker Endpoint stores the logs of every variant separately in JSON files, and we can capture every request by configuring Data Capture with the sampling percentage set to 100%.

```python
data_capture_config = DataCaptureConfig(
    enable_capture=True,
    sampling_percentage=100,
    destination_s3_uri='s3://bucket/logs'
)
```
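The captured records land as JSON Lines files under a per-variant S3 prefix. Below is a sketch of a parser for a single record, assuming the payloads were captured with JSON encoding (BASE64-encoded payloads would need an extra decoding step; `parse_capture_line` is a hypothetical helper name):

```python
import json


def parse_capture_line(line):
    """Extract the request and response payloads from one captured record."""
    record = json.loads(line)
    capture = record['captureData']
    return {
        'input': capture['endpointInput']['data'],
        'output': capture['endpointOutput']['data'],
    }
```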

## Deploying the Endpoint

Finally, we can deploy the endpoint with two variants and data capture:

```python
sagemaker_session.endpoint_from_production_variants(
    name='AB-endpoint-with-monitoring',
    production_variants=[variantA, variantB],
    data_capture_config_dict=data_capture_config._to_request_dict()
)
```
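Once the endpoint is in service, clients call it like any other Sagemaker Endpoint. The response metadata tells us which variant served the request, and `invoke_endpoint` also accepts a `TargetVariant` parameter when a request must be pinned to a specific variant. A sketch (the helper names are mine):

```python
import json


def build_body(instances):
    """Serialize model inputs into the JSON body expected by inference.py."""
    return json.dumps(instances)


def invoke(endpoint_name, instances, target_variant=None):
    """Call the endpoint and return (serving variant, parsed prediction)."""
    import boto3  # imported here so build_body stays testable without AWS access
    runtime = boto3.client('sagemaker-runtime')
    kwargs = dict(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=build_body(instances),
    )
    if target_variant is not None:
        kwargs['TargetVariant'] = target_variant  # pin the request to one variant
    response = runtime.invoke_endpoint(**kwargs)
    variant = response['InvokedProductionVariant']
    prediction = json.loads(response['Body'].read())
    return variant, prediction


# Example:
# variant, prediction = invoke('AB-endpoint-with-monitoring', [[1.0, 2.0, 3.0]])
```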

## What to do when a Sagemaker Endpoint does not run because of the "Invalid protobuf file" error?

Recently, the Sagemaker backend has been updated, and the following function causes errors in some projects:

```python
def find_model_versions(model_path):
    return [version.lstrip("0") for version in os.listdir(model_path) if version.isnumeric()]

# source: https://github.com/aws/deep-learning-containers/blob/fe4864d0ce873c269da58ad8f3d29a4733cddc80/tensorflow/inference/docker/build_artifacts/sagemaker/tfs_utils.py#L137
```

The Sagemaker backend lists the model versions and removes the leading zeros from the version. The problem is that some people who have only one model version use '0' as the version id. That zero gets trimmed to an empty string, and the Sagemaker Endpoint crashes.

This may be the problem in your project if you see the following messages in the Sagemaker Endpoint logs:

```
INFO:__main__:tensorflow serving model config:
model_config_list: {
  config: {
    name: 'saved_model'
    base_path: '/opt/ml/model/tensorflow/saved_model'
    model_platform: 'tensorflow'
    model_version_policy: {
      specific: {
        versions:
      }
    }
  }
}

INFO:__main__:tensorflow version info:
TensorFlow ModelServer: 2.3.0-rc0+dev.sha.no_git
TensorFlow Library: 2.3.0
INFO:__main__:tensorflow serving command: tensorflow_model_server --port=20000 --rest_api_port=20001 --model_config_file=/sagemaker/model-config.cfg --max_num_load_retries=0
INFO:__main__:started tensorflow serving (pid: 16)
INFO:tfs_utils:Trying to connect with model server: http://localhost:20001/v1/models/saved_model
WARNING:urllib3.connectionpool:Retrying (Retry(total=8, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fc8dfc4d6d0>: Failed to establish a new connection: [Errno 111] Connection refused')': /v1/models/saved_model
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:324] Error parsing text-format tensorflow.serving.ModelServerConfig: 9:7: Expected integer, got: }
Failed to start server. Error: Invalid argument: Invalid protobuf file: '/sagemaker/model-config.cfg'
```

I don't know whether an upstream fix exists. The simplest workaround is to repackage the model `tar.gz` file, change the version id to '1', and upload the new archive to S3.
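The repackaging can be scripted with the standard library. A sketch (the `bump_version_dir` helper is my own name, and it assumes the version id appears as a path component inside the archive):

```python
import tarfile


def bump_version_dir(src_path, dst_path, old='0', new='1'):
    """Rewrite a model archive, renaming the version directory `old` to `new`.

    Works around the lstrip("0") bug: a version directory named '0' is
    trimmed to an empty string by the Sagemaker container.
    """
    with tarfile.open(src_path, 'r:gz') as src, \
            tarfile.open(dst_path, 'w:gz') as dst:
        for member in src.getmembers():
            parts = member.name.split('/')
            # Rename only path components that exactly match the version id.
            member.name = '/'.join(new if part == old else part for part in parts)
            fileobj = src.extractfile(member) if member.isfile() else None
            dst.addfile(member, fileobj)
```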