How to A/B test Tensorflow models using Sagemaker Endpoints

How can we run multiple versions of a Tensorflow model in production at the same time? There are many possible solutions. The simplest option is to deploy those versions as two separate models in Tensorflow Serving and switch between them in the application code. However, that quickly becomes difficult to maintain when we want to do a canary release or A/B test more than two models.

Table of Contents

  1. Dependencies and Imports
  2. Creating Model Versions
  3. Creating Variants
  4. Configuring Data Capture
  5. Deploying the Endpoint
  6. What to do when a Sagemaker Endpoint does not run because of the “Invalid protobuf file” error?

Thankfully, Sagemaker Endpoints simplify A/B testing of machine learning models. We can achieve the desired result using a few lines of code.

In this article, I’ll show you how to define multiple ML models, configure them as Sagemaker Endpoint variants, deploy the endpoint, and capture the model results from all deployed versions.

Dependencies and Imports

To run the code, we need two dependencies:

boto3==1.14.12
sagemaker==2.5.3

If you prefer to run the deployment script as a step in AWS CodePipeline, take a look at a separate article on that topic.

I assume that the code runs in an environment where the AWS access key ID and secret access key are provided through environment variables.
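The Sagemaker SDK uses boto3 under the hood. boto3 reads credentials from the standard AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables. sagemaker.Session() also needs a region, which can come from AWS_DEFAULT_REGION.

Below is a minimal pre-flight check, assuming those are the variables your environment uses. A named profile or an instance role would work too, and then this check is unnecessary. The helper name is hypothetical.

```python
import os
from typing import List, Mapping

# Standard environment variables read by boto3 (and therefore by the Sagemaker SDK).
REQUIRED_AWS_ENV_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION")


def missing_aws_env_vars(env: Mapping[str, str] = os.environ) -> List[str]:
    """Return the required AWS variables that are unset or empty (hypothetical helper)."""
    return [name for name in REQUIRED_AWS_ENV_VARS if not env.get(name)]
```

At the top of the deployment script, something like `if missing_aws_env_vars(): raise SystemExit(...)` fails fast with a readable message. Without it, you would only get an error from deep inside boto3.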

In the deployment script, we have to import Sagemaker and create the session.

import sagemaker
from sagemaker.session import production_variant

from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.model_monitor import DataCaptureConfig

sagemaker_session = sagemaker.Session()

role = 'ARN of the role that has access to Sagemaker and the deployment bucket in S3'

Creating Model Versions

Before we continue, we have to archive the saved Tensorflow models as tar.gz files and store them in an S3 bucket. In the next step, we create two models in AWS Sagemaker using those tar.gz files.
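The Sagemaker TensorFlow Serving container looks for a SavedModel inside a numeric version directory, using a layout like `model-name/1/saved_model.pb` plus `model-name/1/variables/`. Below is a minimal packaging sketch using the standard-library tarfile module. It assumes the SavedModel was exported to a local directory; the helper name, model name, and paths are hypothetical.

```python
import os
import tarfile


def package_saved_model(saved_model_dir: str, output_path: str,
                        model_name: str = "model", version: str = "1") -> str:
    """Archive a SavedModel as <model_name>/<version>/... inside a tar.gz (hypothetical helper).

    Use a non-zero numeric version id such as "1".
    """
    if not os.path.isfile(os.path.join(saved_model_dir, "saved_model.pb")):
        raise FileNotFoundError(f"no saved_model.pb in {saved_model_dir}")
    if not version.isdigit():
        raise ValueError("the version directory name must be numeric")
    with tarfile.open(output_path, "w:gz") as archive:
        # arcname places the directory's contents under model_name/version/ in the archive
        archive.add(saved_model_dir, arcname=f"{model_name}/{version}")
    return output_path
```

The archive can then be uploaded with `sagemaker_session.upload_data(path='model-version-a.tar.gz', bucket='bucket', key_prefix='path')`. That returns an S3 URI of the form used as model_data below. Avoid ‘0’ as the version id; the "Invalid protobuf file" section explains why.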

model_version_A = TensorFlowModel(
    name='model-name-version-a',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-a.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

sagemaker_session.create_model(
    'model-version-a',
    role,
    model_version_A.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)

model_version_B = TensorFlowModel(
    name='model-name-version-b',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-b.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

sagemaker_session.create_model(
    'model-version-b',
    role,
    model_version_B.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)

The entry_point is the file containing input_handler and output_handler functions used to convert the HTTP request to input compatible with Tensorflow Serving and convert the response back to the format expected by the client application. The source_dir is the directory where we stored the inference.py script. We can also include the requirements.txt file in the source_dir directory to install additional dependencies.
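For reference, here is a minimal inference.py sketch following the input_handler/output_handler contract of the Sagemaker TensorFlow Serving container:

- input_handler receives the request body as a stream, plus a context object carrying the request's content type. It returns a JSON string for the TensorFlow Serving REST API.
- output_handler receives the TensorFlow Serving response. It returns the body and content type sent back to the client.

The CSV branch is only an illustrative assumption about what the client sends.

```python
import json


def input_handler(data, context):
    """Convert the HTTP request body into a TensorFlow Serving REST request."""
    body = data.read().decode("utf-8")
    if context.request_content_type == "application/json":
        # Assumption: the client already sends {"instances": [...]}, so pass it through.
        return body
    if context.request_content_type == "text/csv":
        # Illustrative assumption: one comma-separated row of floats per line.
        rows = [[float(x) for x in line.split(",")] for line in body.splitlines() if line.strip()]
        return json.dumps({"instances": rows})
    raise ValueError(f"Unsupported content type: {context.request_content_type}")


def output_handler(response, context):
    """Convert the TensorFlow Serving response into the format the client expects."""
    if response.status_code != 200:
        raise ValueError(response.content.decode("utf-8"))
    # Return (body, content type); here TensorFlow Serving's JSON is forwarded unchanged.
    return response.content, "application/json"
```

Both models in this article use the same entry_point and source_dir. Keeping the pre- and post-processing identical means the A/B comparison measures the models rather than the handlers.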

Creating Variants

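Now, we must define the endpoint variants by specifying the model names and the share of traffic routed to each variant (the initial_weight parameter). The weights are relative, not percentages: each variant receives its weight divided by the sum of all weights, so two variants with a weight of 50 each split the traffic evenly: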

variantA = production_variant(
    model_name='model-version-a',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantA",
    initial_weight=50,
)

variantB = production_variant(
    model_name='model-version-b',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantB",
    initial_weight=50,
)
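Because the weights are relative (each variant receives its weight divided by the sum of all weights), the resulting split can be checked with a small helper. traffic_shares is hypothetical and not part of the Sagemaker SDK. These are the target fractions; over a small number of requests, the observed split may differ.

```python
from typing import Dict


def traffic_shares(weights: Dict[str, float]) -> Dict[str, float]:
    """Fraction of requests each variant should receive: weight / sum of all weights."""
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("at least one variant needs a positive weight")
    return {variant: weight / total for variant, weight in weights.items()}
```

For a canary release, you could give the current model initial_weight=90 and the candidate initial_weight=10. With those weights, `traffic_shares({"VariantA": 90, "VariantB": 10})` gives the candidate a target share of 0.1.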

Configuring Data Capture

We want to log the requests and responses to verify which version of the model performs better. Sagemaker Endpoints store the captured data of every variant separately, as JSON Lines files in S3, and we can log every request by configuring Data Capture with the sampling percentage set to 100:

data_capture_config = DataCaptureConfig(
    enable_capture=True,
    sampling_percentage=100,
    destination_s3_uri='s3://bucket/logs'
)
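Data Capture writes its files under a prefix of the form `<destination_s3_uri>/<endpoint-name>/<variant-name>/yyyy/mm/dd/hh/`. Each line holds one captured request, with the request and response under `captureData.endpointInput` and `captureData.endpointOutput`.

Assuming that layout, the sketch below groups the decoded records by variant. The function names are hypothetical. Downloading the files (for example with boto3's S3 client) is left out so the code runs without AWS access; it takes (S3 key, file contents) pairs instead.

```python
import base64
import json
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def variant_from_key(key: str, endpoint_name: str) -> str:
    """Extract the variant name from a capture file key (assumes the layout described above)."""
    parts = key.split("/")
    index = parts.index(endpoint_name)  # raises ValueError if the endpoint name is not in the key
    return parts[index + 1]


def _decode(payload: dict) -> str:
    # Captured data is stored either as text (e.g. encoding "JSON" or "CSV") or base64 ("BASE64").
    if payload.get("encoding") == "BASE64":
        return base64.b64decode(payload["data"]).decode("utf-8")
    return payload["data"]


def group_by_variant(files: Iterable[Tuple[str, str]],
                     endpoint_name: str) -> Dict[str, List[Tuple[str, str]]]:
    """Map variant name -> list of (request, response) pairs, given (key, jsonl_text) files."""
    grouped: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for key, text in files:
        variant = variant_from_key(key, endpoint_name)
        for line in text.splitlines():
            if not line.strip():
                continue
            record = json.loads(line)["captureData"]
            grouped[variant].append(
                (_decode(record["endpointInput"]), _decode(record["endpointOutput"]))
            )
    return dict(grouped)
```

Grouping by the variant directory keeps the comparison simple. You can compute whatever quality metric you care about separately for VariantA and VariantB.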

Deploying the Endpoint

Finally, we can deploy the endpoint with two variants and data capture:

sagemaker_session.endpoint_from_production_variants(
    name='AB-endpoint-with-monitoring',
    production_variants=[variantA, variantB],
    data_capture_config_dict=data_capture_config._to_request_dict()
)
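After deployment, clients call the endpoint through the sagemaker-runtime API. The invoke_endpoint response includes an InvokedProductionVariant field naming the variant that served the request. The optional TargetVariant parameter bypasses the weighted routing and sends a request to one specific variant, which is handy for smoke-testing each model.

Below is a minimal sketch, assuming a boto3 sagemaker-runtime client is passed in and that inference.py accepts JSON. The helper name is hypothetical.

```python
import json
from typing import Any, Optional, Tuple


def predict(runtime_client: Any, endpoint_name: str, instances: list,
            target_variant: Optional[str] = None) -> Tuple[str, Any]:
    """Send one JSON request; return (variant that served it, decoded response body)."""
    kwargs = {
        "EndpointName": endpoint_name,
        "ContentType": "application/json",  # assumption: the entry point accepts JSON
        "Body": json.dumps({"instances": instances}),
    }
    if target_variant is not None:
        # Bypass the weighted routing and send the request to one specific variant.
        kwargs["TargetVariant"] = target_variant
    response = runtime_client.invoke_endpoint(**kwargs)
    body = json.loads(response["Body"].read())
    return response["InvokedProductionVariant"], body
```

In a real script, you would call `predict(boto3.client('sagemaker-runtime'), 'AB-endpoint-with-monitoring', [[1.0, 2.0]])`. Logging the returned variant name on the client side lets you cross-check it against the captured data.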

What to do when a Sagemaker Endpoint does not run because of the “Invalid protobuf file” error?

Recently, the Sagemaker backend (the TensorFlow Serving inference container) has been updated, and the following function now causes errors in some projects:

import os

def find_model_versions(model_path):
    # Returns the numeric subdirectories of model_path with their leading zeros removed
    return [version.lstrip("0") for version in os.listdir(model_path) if version.isnumeric()]

# source: https://github.com/aws/deep-learning-containers/blob/fe4864d0ce873c269da58ad8f3d29a4733cddc80/tensorflow/inference/docker/build_artifacts/sagemaker/tfs_utils.py#L137

The Sagemaker backend lists the model versions and removes the leading zeros from each version id. The problem is that some people who have only one model version use ‘0’ as the version id. That zero gets trimmed to an empty string. The generated TensorFlow Serving configuration then contains an empty versions: entry, TensorFlow Serving fails to parse it, and the Sagemaker Endpoint crashes.
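The failure is easy to reproduce locally. str.lstrip("0") removes every leading zero, so a version directory named "0" becomes an empty string, while "1" and "007" survive as "1" and "7". The self-contained sketch below uses a copy of the container's find_model_versions function; the directory names and the versions_for helper are made up for the demonstration.

```python
import os
import tempfile


def find_model_versions(model_path):
    # Copy of the container's helper: numeric subdirectories with leading zeros stripped.
    return [version.lstrip("0") for version in os.listdir(model_path) if version.isnumeric()]


def versions_for(dir_names):
    """Create the given subdirectories in a temporary dir and run find_model_versions on it."""
    with tempfile.TemporaryDirectory() as model_path:
        for name in dir_names:
            os.mkdir(os.path.join(model_path, name))
        return sorted(find_model_versions(model_path))


print(versions_for(["0"]))                 # [''] -> empty version id in the TF Serving config
print(versions_for(["1"]))                 # ['1']
print(versions_for(["007", "variables"]))  # ['7'] ("variables" is not numeric, so it is skipped)
```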

This is probably what is happening in your project if you see the following messages in the Sagemaker Endpoint logs (note the empty versions: field and the “Invalid protobuf file” error at the end):

INFO:__main__:tensorflow serving model config:
model_config_list: {
  config: {
    name: 'saved_model'
    base_path: '/opt/ml/model/tensorflow/saved_model'
    model_platform: 'tensorflow'
    model_version_policy: {
      specific: {
        versions:
      }
    }
  }
}

INFO:__main__:tensorflow version info:
TensorFlow ModelServer: 2.3.0-rc0+dev.sha.no_git
TensorFlow Library: 2.3.0
INFO:__main__:tensorflow serving command: tensorflow_model_server --port=20000 --rest_api_port=20001 --model_config_file=/sagemaker/model-config.cfg --max_num_load_retries=0
INFO:__main__:started tensorflow serving (pid: 16)
INFO:tfs_utils:Trying to connect with model server: http://localhost:20001/v1/models/saved_model
WARNING:urllib3.connectionpool:Retrying (Retry(total=8, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fc8dfc4d6d0>: Failed to establish a new connection: [Errno 111] Connection refused')': /v1/models/saved_model
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:324] Error parsing text-format tensorflow.serving.ModelServerConfig: 9:7: Expected integer, got: }
Failed to start server. Error: Invalid argument: Invalid protobuf file: '/sagemaker/model-config.cfg'

I don’t know of a cleaner fix. The simplest workaround is to repackage the model tar.gz file and rename the version directory from ‘0’ to ‘1’.
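Below is a minimal sketch of that repackaging with Python's tarfile module. It assumes the version directory is the first path component named "0" in each archive member; check the archive with `tar -tzf` first. The helper name is hypothetical.

```python
import tarfile


def rename_version_dir(src_path: str, dst_path: str, old: str = "0", new: str = "1") -> None:
    """Copy a model tar.gz, renaming the version directory `old` to `new` (hypothetical helper)."""
    with tarfile.open(src_path, "r:gz") as src, tarfile.open(dst_path, "w:gz") as dst:
        for member in src.getmembers():
            # Regular files need their contents; other members are copied as headers only.
            fileobj = src.extractfile(member) if member.isfile() else None
            parts = member.name.split("/")
            if old in parts:
                # Rename only the first matching component (assumed to be the version directory).
                parts[parts.index(old)] = new
                member.name = "/".join(parts)
            dst.addfile(member, fileobj)
```

After running it, upload the new archive to S3 and point model_data at it. Then create the Sagemaker model and endpoint again.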
