What is text summarization, and why may you need it?
Text summarization is the automated process of generating a concise and cohesive summary of a longer document. We use it to quickly get the main point of a long article or to extract essential information from many documents. Text summarization saves time by condensing an article into several bullet points or a description a few sentences long.
Table of Contents
- What is text summarization, and why may you need it?
- How to choose a text summarization model in HuggingFace?
- What is Qwak ML Platform?
- How to deploy the model with Qwak?
- How to test the model in the Qwak platform?
- How to use a model deployed in the Qwak ML platform?
We can summarize the text in two ways:
- Extractive summarization - selects key sentences or phrases from the original text and includes them in the summary
- Abstractive summarization - generates a new, shortened version of the text using natural language processing techniques.
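To make the difference concrete, here is a minimal sketch of both approaches. The word-frequency scorer is a deliberately naive stand-in for a real extractive method, and `sshleifer/distilbart-cnn-12-6` is just an example abstractive checkpoint, not the model we deploy later in this article:

```python
from collections import Counter
from transformers import pipeline

ARTICLE = "..."  # a long news article goes here

def naive_extractive_summary(text: str, n: int = 2) -> str:
    # Extractive: copy the n sentences containing the most frequent words, verbatim
    sentences = [s.strip() for s in text.split(".") if s.strip()]
    word_freq = Counter(text.lower().split())
    top = sorted(sentences, key=lambda s: sum(word_freq[w] for w in s.lower().split()), reverse=True)
    return ". ".join(top[:n]) + "."

# Abstractive: a seq2seq model writes new sentences
abstractive = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

print(naive_extractive_summary(ARTICLE))
print(abstractive(ARTICLE, max_length=60, min_length=10)[0]["summary_text"])
```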
When do we need it? Imagine you receive a long email or a tech support ticket from a client. Who should handle the issue? How do you assign the task to the correct people or departments? Ticket assignment may involve a text classification model or a person reading the email. Either way, the work gets easier when the first point of contact doesn’t need to read and process the entire message. Instead, we can send the message to a text summarization service and get the gist of the content.
What else can you achieve with text summarization?
- Automatically generate meeting summaries from a transcribed recording of the meeting.
- Generate social media posts from published articles and ebooks.
- Summarize research reports and legal documents without involving large teams of executive assistants.
How to choose a text summarization model in HuggingFace?
We need to consider at least four factors:
- Extractive vs. abstractive summarization: Do we need to extract key sentences from the document or generate an abstract?
- What languages does the model need to support?
- Do we need a general language or domain-specific model (for example, a model capable of summarizing medical or legal terminology)?
- What are the deployment constraints? Can we use a large and powerful model, or is it too expensive to run?
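If you prefer browsing candidates programmatically instead of clicking through the website, the `huggingface_hub` library can list models by task. This is a quick sketch and assumes a recent version of `huggingface_hub`:

```python
from huggingface_hub import list_models

# List the five most-downloaded summarization models on the Hugging Face Hub
for model in list_models(filter="summarization", sort="downloads", direction=-1, limit=5):
    print(model.id)
```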
Right now, I want an abstractive summarization model for English news articles, and I want a small model. Therefore, I picked the `mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization` model.
It’s a BERT2BERT transformer encoder-decoder model built from two small BERTs and fine-tuned on the CNN/Daily Mail dataset. What does that mean?
What is a BERT transformer encoder-decoder model?
The model consists of two main parts: an encoder and a decoder. The encoder processes the input text and produces a numeric representation (called an embedding) of the input. The decoder then generates the output text based on the embeddings.
BERT is based on the transformer architecture, whose core feature is called “attention.” The model learns to assign a weight (an attention score) to each input token, indicating how important that token is for generating the current output token.
The attention scores are then used to weight the encoder’s output, and this weighted sum is used as additional input to the decoder when generating the next output token. In this way, the attention mechanism allows the model to selectively “focus” on certain parts of the input when generating each output token rather than just using the entire input context.
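As a rough illustration, here is the generic scaled dot-product attention computation in PyTorch. This is the textbook formula, not code extracted from the model we will deploy:

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value):
    # scores[i, j]: how relevant input token j is when producing output token i
    scores = query @ key.transpose(-2, -1) / key.size(-1) ** 0.5
    weights = F.softmax(scores, dim=-1)  # attention weights over the input tokens
    return weights @ value, weights

# Toy example: a batch of 1 sequence, 4 tokens, embedding size 8
q, k, v = (torch.randn(1, 4, 8) for _ in range(3))
context, attention = scaled_dot_product_attention(q, k, v)
print(attention[0].sum(dim=-1))  # every row of weights sums to 1.0
```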
Of course, the “selective focus” feature is crucial when extracting the most important information from a document.
What is a bidirectional encoder?
A unidirectional encoder takes into account only the past context of the input text because it processes the input sequence from left to right. Bidirectional encoders, however, capture context from both the “past” (words already seen in the current input) and the “future” (text after the currently processed token).
In recurrent architectures, bidirectionality is implemented with two separate encoders that process the input sequence in opposite directions; their outputs are then combined to form the final representation of the input. Transformer encoders like BERT achieve the same effect more directly: the self-attention mechanism lets every token attend to all other tokens in the sequence, on both sides, at once.
Overall, bidirectional encoders improve the performance of natural language processing tasks by allowing the model to consider the full context of the input text. Models with bidirectional encoders are an excellent choice for our use case because we can’t summarize text without considering the context.
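We can see the bidirectional context in action with BERT’s masked-language-modeling head. In the sketch below (using the publicly available `bert-base-uncased` checkpoint purely as an illustration), the word “barked” to the right of the mask pushes the prediction toward dog-like tokens:

```python
from transformers import pipeline

# BERT fills in the mask using context from BOTH sides of the gap
fill = pipeline("fill-mask", model="bert-base-uncased")

for candidate in fill("The [MASK] barked at the mail carrier all morning."):
    print(candidate["token_str"], round(candidate["score"], 3))
```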
What is Qwak ML Platform?
Qwak is an end-to-end ML platform. It provides infrastructure for model training and deployment, feature stores, model monitoring, and model re-training automation. To use the Qwak platform, we need to provide the training code (or, in this case, download an existing model from an online repository) and the data preprocessing/postprocessing code for inference. Everything else happens automatically.
How to deploy the model with Qwak?
First, we must create a Python (3.7–3.9) virtual environment and install the `qwak-sdk`. I will use `pipenv`, but any virtual environment management software will do.
```bash
pipenv --python 3.9
pipenv shell
pip install --extra-index-url https://qwak:A3RX55aNSpE8dCV@qwak.jfrog.io/artifactory/api/pypi/qwak-pypi/simple qwak-sdk
```
If you are using Qwak for the first time, you must create an account and run `qwak configure` before continuing with the code below!
We must create a project and a model in the Qwak platform. A project groups models. Every model may contain multiple model versions and one deployed endpoint (which may run multiple model versions in an A/B testing or shadow deployment setup). Usually, projects denote use cases, and the models within a project are services supporting the use case.
Let’s create the project using the command line (we can also do this in the web UI):
```bash
qwak projects create --project-name "Text Summarization" --project-description "News summarization models"
```
The command will return an identifier for the project. We have to pass this ID as an argument when we add a model to the project:
```bash
qwak models create --project-id {project_id} --model-name "bert_summarizer" --model-description "Explain why you need the model"
```
Now, we can create a model directory. The directory must contain at least three files: `__init__.py`, `model.py`, and a dependency manager configuration. If we use Conda, the dependency configuration file is called `conda.yml`.
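For reference, the layout we end up with looks something like this (the top-level directory name is arbitrary, and the `tests/it` directory appears later in the article):

```
bert_summarizer/
├── __init__.py
├── model.py
├── conda.yml
└── tests/
    └── it/
        └── test_realtime_api.py
```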
In the `conda.yml` file, we put the required Python dependencies. In the same file, we also include the Python version and the Qwak SDK:
```yaml
name: ArticleSummarization
channels:
  - defaults
  - conda-forge
dependencies:
  - python=3.8
  - pip=22.2.2
  - pandas=1.1.5
  - pytorch  # model.py below imports torch, so we list PyTorch explicitly
  - transformers=4.24.0
  - pip:
      - --extra-index-url https://qwak:A3RX55aNSpE8dCV@qwak.jfrog.io/artifactory/api/pypi/qwak-pypi/simple
      - qwak-sdk==0.9.115
```
In the `__init__.py` file, we load the model implementation:

```python
from .model import DocumentSummarization


def load_model():
    return DocumentSummarization()
```
Now, the `model.py` file. In this file, we have to extend the `QwakModelInterface` and implement two methods: `build` and `predict`. In the `build` function, we can train or download an existing model. Whatever we choose, we must remember to store the model as a field in the class. For this model, we download a tokenizer too. Hence, we store both objects as fields.
The `build` function runs only once, when we build a new version of the model. After the build, the Qwak platform serializes the model object (pickles the entire `QwakModelInterface` implementation we provided) and deploys it. The platform passes requests to our `predict` function in the deployed container. The function has access to everything we have stored as fields.
```python
import pandas as pd
import qwak
import torch
from qwak.model.base import QwakModelInterface
from transformers import BertTokenizerFast, EncoderDecoderModel


class DocumentSummarization(QwakModelInterface):
    def __init__(self):
        self.model = None
        self.tokenizer = None

    def build(self):
        # Download the pretrained tokenizer and model, and store both as fields
        # so they survive serialization and are available in `predict`
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizerFast.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization')
        self.model = EncoderDecoderModel.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization').to(device)

    @qwak.api(analytics=True)
    def predict(self, df) -> pd.DataFrame:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        texts = list(df['text'].values)
        result = []
        for text in texts:
            # Tokenize the article, truncating it to the model's 512-token input limit
            inputs = self.tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)
            output = self.model.generate(input_ids, attention_mask=attention_mask)
            single_result = self.tokenizer.decode(output[0], skip_special_tokens=True)
            result.append({"summary_text": single_result})
        return pd.DataFrame(result)
```
The `predict` function will receive the requests as a Pandas DataFrame with a `text` column. We return another Pandas DataFrame with a `summary_text` column.
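For illustration, here is what that contract looks like in plain Pandas (a sketch; the column names come from the article, the example values are made up):

```python
import pandas as pd

# Input: one row per article to summarize
request_df = pd.DataFrame({"text": ["First article ...", "Second article ..."]})

# Output: one summary per input row, e.g.
# pd.DataFrame([{"summary_text": "..."}, {"summary_text": "..."}])
```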
How to test the model in the Qwak platform?
The Qwak platform doesn’t require adding tests to your model, but it’s my blog, so we must adequately test everything. We use a model created by someone else, and we don’t have any complex preprocessing code, so it’s sufficient to write a single integration test to verify whether everything works.
However, in the case of text summarization models, the response may be slightly different every time we call the model. Because of that, I will prepare a list of several expected values, and the model’s response must match at least one of them.
We create the `tests/it` directory and add the `pytest` tests. You don’t need to add `pytest` to the model dependencies. In the integration test, Qwak will deploy the model in a Docker container and use the SDK client to interact with it:
```python
import pandas as pd
from qwak_mock import real_time_client


def test_realtime_api(real_time_client):
    articles_to_summarize = [
        """
        TODO: put the article here
        """
    ]
    expected_summary = [
        "TODO put the summary here",
        "TODO another one"
    ]

    # The model expects a DataFrame with a `text` column
    feature_vector = pd.DataFrame({"text": articles_to_summarize})
    actual_result = real_time_client.predict(feature_vector)

    # The summary may differ slightly between calls,
    # so we accept a match with any of the expected values
    found = False
    for summary in expected_summary:
        if summary in str(actual_result):
            found = True
    assert found
```
Building and deploying the model
Finally, we can build the model using the `qwak models build` command in the directory with the model code:

```bash
qwak models build --model-id "bert_summarizer" .
```

The command returns a build identifier; we will need it in the next step.
After the build finishes, we can deploy an endpoint with our model. We could click the “deploy” button in the web UI, but Qwak is an SDK-first service, so we can do everything in the command line:
```bash
qwak models deploy realtime \
  --model-id bert_summarizer \
  --build-id {the build id} \
  --pods 1 \
  --cpus 1 \
  --memory 2048
```
After a while, we get a working model endpoint, and we can start sending requests.
How to use a model deployed in the Qwak ML platform?
Qwak offers Python, Java, and Go inference SDKs. We can add the Qwak inference SDK to the dependencies of the client application to call the model. If we can’t add configuration files and a new dependency to the service, or it’s written in a language Qwak doesn’t support, we can use the REST API to interact with the model.
In this article, I will use the REST API.
Before we start, we have to send our Qwak API key to the authentication service and obtain a token. The token is valid for 24 hours. If we used `curl`, the command would look like this:
```bash
curl --location --request POST 'https://grpc.qwak.ai/api/v1/authentication/qwak-api-key' \
  --header 'Content-Type: application/json' \
  --data '{"qwakApiKey": "<API_Key>"}'
```
After obtaining the token, we can send the requests to the model:
```bash
curl --location --request POST 'https://models.<environment_name>.qwak.ai/v1/bert_summarizer/predict' \
  --header 'Content-Type: application/json' \
  --header 'Authorization: Bearer <Auth Token>' \
  --data '{"columns": ["text"], "index": [0], "data": [["TEXT TO SUMMARIZE"]]}'
```