Zero-Shot Text Classification with BERT: No Training Data Required!
NLP
Text Classification
BERT
Zero-Shot Learning
Transformers
Hugging Face
Python
AI


A practical guide to performing text classification using BERT without any labeled training data, leveraging the power of pre-trained language models.

March 23, 2024
3 minutes


Introduction

Traditional text classification typically requires a substantial amount of labeled data to train a model. You need examples of text belonging to each category you want to classify. However, what if you don't have any labeled data? Or what if you need to classify text into categories you didn't anticipate during training? This is where zero-shot learning comes in.

Zero-shot learning allows a model to perform a task it wasn't explicitly trained for. In the context of text classification, this means classifying text into categories without having seen any labeled examples of those categories. This seemingly magical feat is achieved by leveraging the rich semantic understanding embedded within large, pre-trained language models like BERT (Bidirectional Encoder Representations from Transformers).

This guide will explain how zero-shot text classification with BERT works and demonstrate how to implement it using the Hugging Face Transformers library in Python.

How Zero-Shot Classification with BERT Works

The key idea behind zero-shot classification with BERT (and similar models) is to frame the classification task as a Natural Language Inference (NLI) problem. NLI involves determining the relationship between two sentences: a premise and a hypothesis. The relationship can be one of:

  • Entailment: The hypothesis is true given the premise.
  • Contradiction: The hypothesis is false given the premise.
  • Neutral: The premise neither confirms nor rules out the hypothesis.

In zero-shot classification, we treat the text we want to classify as the premise, and we create a hypothesis for each potential category. The hypothesis is a statement that describes the category. For example:

  • Text (Premise): "The company announced record profits for the third quarter."
  • Category 1 Hypothesis: "This text is about business."
  • Category 2 Hypothesis: "This text is about sports."
  • Category 3 Hypothesis: "This text is about politics."

We then use a pre-trained NLI model (typically a Transformer such as BERT, RoBERTa, or BART that has been fine-tuned on an NLI dataset) to score the relationship between the text (premise) and each hypothesis. The category whose hypothesis receives the highest entailment score is the predicted category.
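To make this last step concrete, here is a minimal sketch of how per-label entailment scores can be turned into a ranking. It assumes you already have one entailment logit per candidate label from an NLI model; the numbers below are made-up placeholders, not real model output. To our understanding, this mirrors what the Hugging Face pipeline does in its default single-label mode: a softmax over the entailment logits of all candidate labels, so the resulting scores sum to 1. The helper names are hypothetical.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def single_label_prediction(labels, entailment_logits):
    """Rank labels by a softmax taken across their entailment logits.

    entailment_logits[i] is the entailment logit an NLI model produced for
    the premise paired with the hypothesis built from labels[i].
    """
    scores = softmax(entailment_logits)
    return sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)


# Placeholder logits for illustration only; a real model would supply these.
labels = ["business", "sports", "politics"]
entailment_logits = [3.1, -1.2, 0.4]

ranked = single_label_prediction(labels, entailment_logits)
for label, score in ranked:
    print(f"{label}: {score:.3f}")
```

The first entry of `ranked` is the predicted category. Because the softmax runs across labels, the labels compete with one another: raising one label's score necessarily lowers the others.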

Why Use BERT for Zero-Shot Classification?

BERT and other Transformer-based models are excellent candidates for zero-shot learning because:

  1. Pre-training on Massive Datasets: BERT is pre-trained on a vast amount of text data, learning rich representations of language and world knowledge. This pre-training allows it to understand the relationships between words, sentences, and concepts, even without explicit training on a specific classification task.
  2. Contextualized Embeddings: BERT produces contextualized word embeddings, meaning the representation of a word changes depending on the surrounding words. This is crucial for understanding the nuances of language and accurately determining the relationship between the premise and hypothesis.
  3. Fine-tuning for NLI: Many publicly available BERT models have been fine-tuned on NLI datasets, making them particularly well-suited for this task.

Implementation with Hugging Face Transformers

The Hugging Face Transformers library provides a simple and convenient way to perform zero-shot classification. We'll use the pipeline API for ease of use.

Step 1: Install Transformers

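If you haven't already, install the Transformers library together with a deep-learning backend such as PyTorch (the pipeline needs one to run the model):

```bash
pip install transformers torch
```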

Step 2: Load the Zero-Shot Classification Pipeline

```python
from transformers import pipeline

# Downloads the model on first use, then loads it from the local cache.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
```

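We're using the facebook/bart-large-mnli model, a BART model fine-tuned on the MultiNLI dataset, which is a large-scale NLI dataset. BART is an encoder-decoder Transformer rather than BERT itself, but the NLI-based approach works the same way with any model fine-tuned for NLI, and this checkpoint is a popular default for zero-shot classification. Other good options include: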

  • joeddav/xlm-roberta-large-xnli: A multilingual model.
  • typeform/distilbert-base-uncased-mnli: A smaller, faster (but potentially less accurate) model.
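Whichever model you choose, the position of the "entailment" class in the output logits is not standardized across NLI checkpoints. If you ever work with model outputs directly instead of through the pipeline, look the index up from the model's `config.label2id` rather than hard-coding it. The sketch below uses a hypothetical helper, `find_entailment_id`, that matches the first label whose name starts with "entail". The mapping shown for facebook/bart-large-mnli is our assumption; confirm it with `AutoConfig.from_pretrained("facebook/bart-large-mnli").label2id`.

```python
def find_entailment_id(label2id):
    """Return the logit index of the entailment class, or -1 if none is found.

    label2id maps class names to indices, as in a Transformers model config.
    Matching is case-insensitive on the "entail" prefix.
    """
    for label, index in label2id.items():
        if label.lower().startswith("entail"):
            return index
    return -1


# Assumed mapping for facebook/bart-large-mnli; verify against the real config.
bart_mnli_label2id = {"contradiction": 0, "neutral": 1, "entailment": 2}
# A made-up mapping with a different order, showing why hard-coding is fragile.
other_label2id = {"ENTAILMENT": 0, "NEUTRAL": 1, "CONTRADICTION": 2}

print(find_entailment_id(bart_mnli_label2id))
print(find_entailment_id(other_label2id))
```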

Step 3: Classify Text

Now, let's classify some text. We provide the text to classify and a list of candidate labels (categories).

```python
sequence_to_classify = "The company announced record profits for the third quarter."
candidate_labels = ["business", "sports", "politics"]

result = classifier(sequence_to_classify, candidate_labels)
print(result)
```

The output will be a dictionary containing:

  • sequence: The input text.
  • labels: The candidate labels, sorted by their predicted probability.
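  • scores: The score for each label, in the same order as labels. In the default single-label mode, these come from a softmax over the labels' entailment scores, so they sum to 1.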

The output might look something like this:

```python
{'sequence': 'The company announced record profits for the third quarter.',
 'labels': ['business', 'politics', 'sports'],
 'scores': [0.9638, 0.0287, 0.0075]}
```

In this case, "business" is correctly identified as the most likely category.
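Because `labels` and `scores` are parallel lists already sorted from most to least likely, the top prediction is simply the first element of each. A small helper such as the hypothetical `top_prediction` below keeps that logic in one place; here it is applied to the illustrative output above.

```python
def top_prediction(result):
    """Return (label, score) for the highest-scoring label in a pipeline result.

    Assumes the zero-shot pipeline's output format: a dict whose 'labels' and
    'scores' lists are parallel and sorted by descending score.
    """
    return result["labels"][0], result["scores"][0]


# The example output shown above (scores are illustrative).
result = {
    "sequence": "The company announced record profits for the third quarter.",
    "labels": ["business", "politics", "sports"],
    "scores": [0.9638, 0.0287, 0.0075],
}

label, score = top_prediction(result)
print(f"Predicted: {label} ({score:.1%})")

# Pairing labels with scores is handy for logging or building a table.
print(dict(zip(result["labels"], result["scores"])))
```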

Step 4: Multi-Label Classification (Optional)

By default, the pipeline assumes you want to perform single-label classification (assign only one label to the text). If you want to allow for multiple labels, you can set multi_label=True:

```python
sequence_to_classify = "The new policy sparked protests and debates in parliament."
candidate_labels = ["business", "politics", "social issues", "economics"]

result = classifier(sequence_to_classify, candidate_labels, multi_label=True)
print(result)
```

The output now scores each label independently: each score is the probability that the text entails that label's hypothesis rather than contradicts it, so the scores won't necessarily sum to 1. You might get something like:

```python
{'sequence': 'The new policy sparked protests and debates in parliament.',
 'labels': ['politics', 'social issues', 'economics', 'business'],
 'scores': [0.9872, 0.9541, 0.6128, 0.0315]}
```

Here, both "politics" and "social issues" are correctly identified as highly likely categories.
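The difference from single-label mode lies in how scores are normalized. To our understanding of the pipeline's implementation, multi-label mode scores each label on its own with a two-way softmax over that pair's contradiction and entailment logits, ignoring the neutral logit. The sketch below reproduces that with placeholder logits, then shows a common follow-up: keeping every label whose score clears a threshold. The helper names are hypothetical and the 0.5 threshold is an arbitrary assumption to tune on your own data.

```python
import math


def multi_label_score(contradiction_logit, entailment_logit):
    """Probability of entailment vs. contradiction for one label.

    A two-way softmax, which equals the logistic sigmoid of the
    difference between the entailment and contradiction logits.
    """
    return 1.0 / (1.0 + math.exp(contradiction_logit - entailment_logit))


def labels_above_threshold(result, threshold=0.5):
    """Keep every label whose independent score meets the threshold."""
    return [
        label
        for label, score in zip(result["labels"], result["scores"])
        if score >= threshold
    ]


# Placeholder (contradiction, entailment) logits; not real model output.
logits = {"politics": (-2.0, 3.0), "business": (2.5, -1.0)}
for label, (contradiction, entailment) in logits.items():
    print(label, round(multi_label_score(contradiction, entailment), 3))

# The multi-label example output shown above (scores are illustrative).
result = {
    "sequence": "The new policy sparked protests and debates in parliament.",
    "labels": ["politics", "social issues", "economics", "business"],
    "scores": [0.9872, 0.9541, 0.6128, 0.0315],
}
print(labels_above_threshold(result, threshold=0.5))
```

With a threshold of 0.5, "economics" is kept as well; a stricter threshold such as 0.9 would keep only "politics" and "social issues".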

Step 5: Using a Custom Hypothesis Template (Optional)

You can customize the hypothesis template used by the classifier. By default, the pipeline uses "This example is {}.", where {} is replaced by each candidate label. You can change this to be more specific or to better reflect the nuances of your task.

```python
sequence_to_classify = "The midfielder scored a stunning goal in the final minutes."
candidate_labels = ["soccer", "basketball", "tennis"]
hypothesis_template = "The topic of this article is {}."

result = classifier(sequence_to_classify, candidate_labels, hypothesis_template=hypothesis_template)
print(result)
```

This allows for greater flexibility and can sometimes improve accuracy.
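Behind the scenes, the template is combined with each candidate label to form one hypothesis per label, and every hypothesis is paired with the same input text. The hypothetical `make_nli_pairs` helper below sketches that step. It also rejects templates without a `{}` placeholder, since every label would otherwise produce an identical hypothesis; the pipeline, as far as we know, raises an error in that case too.

```python
def make_nli_pairs(sequence, candidate_labels, hypothesis_template="This example is {}."):
    """Build (premise, hypothesis) pairs, one per candidate label.

    Hypothetical illustration, not library code. The default template
    matches the pipeline's default.
    """
    if hypothesis_template.format("__label__") == hypothesis_template:
        raise ValueError("hypothesis_template must contain a '{}' placeholder")
    return [(sequence, hypothesis_template.format(label)) for label in candidate_labels]


pairs = make_nli_pairs(
    "The midfielder scored a stunning goal in the final minutes.",
    ["soccer", "basketball", "tennis"],
    hypothesis_template="The topic of this article is {}.",
)
for premise, hypothesis in pairs:
    print(hypothesis)
```

Printing the generated hypotheses is a quick way to check that a template reads naturally for every label before running the model.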

Step 6: Handling Longer Sequences (Optional)

If your text is longer than the model's maximum input length (512 tokens for BERT-based models, 1,024 for BART-large models such as facebook/bart-large-mnli), it has to be truncated or split. The pipeline truncates automatically: it shortens the input text, never the hypothesis, so that each premise-hypothesis pair fits, and anything beyond the limit is ignored.

```python
very_long_sequence = "..."  # A very long text
candidate_labels = ["...", "...", "..."]

# No extra argument is needed: the pipeline truncates the input text
# (not the hypothesis) to the model's maximum length.
result = classifier(very_long_sequence, candidate_labels)

# Or, split the text into chunks and classify each chunk separately.
# (This requires more custom code to handle the splitting and aggregation of results.)
```

If you rely on truncation, make sure the important information survives it. Truncation removes text from the end, which suits content where the key information comes first. If it tends to appear later, extract the relevant part yourself before classifying.
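When truncation would discard too much, the chunking approach from the code comment above can be sketched as follows. This is one possible design, not a built-in feature: the hypothetical `classify_long_text` splits the text into overlapping word windows (words are only a rough stand-in for tokens, so keep `chunk_words` well below the model's limit), classifies each window with any callable that behaves like the pipeline, such as the `classifier` from Step 2, and averages each label's score across windows. Averaging is an assumption; taking the maximum may suit texts where a topic appears in only one section. The stand-in classifier exists only so the sketch runs offline.

```python
from collections import defaultdict


def split_into_chunks(text, chunk_words=300, overlap=50):
    """Split text into overlapping windows of whitespace-separated words."""
    if overlap >= chunk_words:
        raise ValueError("overlap must be smaller than chunk_words")
    words = text.split()
    step = chunk_words - overlap
    chunks = []
    for start in range(0, max(len(words), 1), step):
        chunks.append(" ".join(words[start:start + chunk_words]))
        if start + chunk_words >= len(words):
            break
    return chunks


def classify_long_text(classifier, text, candidate_labels, chunk_words=300, overlap=50):
    """Classify each chunk and average every label's score across chunks."""
    chunks = split_into_chunks(text, chunk_words, overlap)
    totals = defaultdict(float)
    for chunk in chunks:
        result = classifier(chunk, candidate_labels)
        for label, score in zip(result["labels"], result["scores"]):
            totals[label] += score
    averaged = {label: total / len(chunks) for label, total in totals.items()}
    return sorted(averaged.items(), key=lambda item: item[1], reverse=True)


def stand_in_classifier(text, candidate_labels):
    """Offline stand-in with the pipeline's output shape (not a real model)."""
    sports = 0.9 if "goal" in text else 0.1
    scores = {"sports": sports, "business": 1.0 - sports}
    ordered = sorted(candidate_labels, key=scores.get, reverse=True)
    return {"sequence": text, "labels": ordered, "scores": [scores[l] for l in ordered]}


long_text = "quarterly revenue grew " * 200 + "then a late goal won the match"
print(classify_long_text(stand_in_classifier, long_text, ["sports", "business"]))
```

Here only the last of three chunks mentions the goal, so averaging ranks "business" first; that is exactly the case where max-pooling the chunk scores might be the better choice.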

Step 7: Batch Processing (Optional)

For improved performance on many sequences, use batching:

```python
sequences = [
    "The company announced record profits.",
    "The team won the championship.",
    "The president gave a speech."
]
candidate_labels = ["business", "sports", "politics"]

results = classifier(sequences, candidate_labels)

for result in results:
    print(result)
```

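When you pass a list of sequences, the pipeline returns a list with one result dictionary per sequence. To send several premise-hypothesis pairs through the model at once (mostly useful on a GPU), you can also pass a batch_size argument, for example classifier(sequences, candidate_labels, batch_size=8).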
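Since each sequence is paired with every candidate label, the number of premise-hypothesis pairs the model must score is the number of sequences times the number of labels, which is worth estimating before a large job. The sketch below uses two hypothetical helpers: `count_nli_pairs` for that estimate and `summarize` to flatten a list of pipeline results into (sequence, top label, score) rows. To keep it runnable offline, it reuses the illustrative Step 3 output instead of calling the model.

```python
def count_nli_pairs(sequences, candidate_labels):
    """Number of premise-hypothesis pairs the model must score."""
    return len(sequences) * len(candidate_labels)


def summarize(results):
    """Flatten pipeline results into (sequence, top label, top score) rows."""
    return [(r["sequence"], r["labels"][0], r["scores"][0]) for r in results]


sequences = [
    "The company announced record profits.",
    "The team won the championship.",
    "The president gave a speech.",
]
candidate_labels = ["business", "sports", "politics"]
print(count_nli_pairs(sequences, candidate_labels))  # prints 9

# A one-item list reusing the illustrative Step 3 output; with the real
# pipeline you would use results = classifier(sequences, candidate_labels).
results = [
    {
        "sequence": "The company announced record profits for the third quarter.",
        "labels": ["business", "politics", "sports"],
        "scores": [0.9638, 0.0287, 0.0075],
    }
]
for row in summarize(results):
    print(row)
```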

Limitations and Considerations

  • Performance: While zero-shot classification is impressive, it typically won't achieve the same accuracy as a model fine-tuned on a large, labeled dataset for the specific categories you're interested in.
  • Label Choice: The choice of candidate labels significantly impacts the results. Labels should be clear, concise, and unambiguous. Avoid overlapping or overly broad categories.
  • Hypothesis Formulation: The way you phrase the hypothesis can also affect performance. Experiment with different templates to see what works best.
  • Computational Cost: While zero-shot classification avoids the cost of training, inference is still expensive: the model scores one (text, candidate label) pair at a time, so the cost grows linearly with the number of candidate labels, and long texts make every pass slower.
  • Language: The facebook/bart-large-mnli model is trained primarily on English data. While it can handle other languages to some extent, performance will likely be better for English text. For other languages, consider using a multilingual model like joeddav/xlm-roberta-large-xnli.
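Because label wording and template phrasing both matter, it can help to compare a few templates on a small hand-labeled sample before committing to one. The labels are used only to evaluate, never to train. The hypothetical `template_accuracy` helper below scores each template by top-1 accuracy, assuming `classifier` is the pipeline from Step 2 (any callable with the same signature and output format works). The stand-in classifier, which ignores the template, exists only so the sketch runs offline; with the real pipeline the accuracies will differ between templates.

```python
def template_accuracy(classifier, examples, candidate_labels, templates):
    """Return {template: top-1 accuracy} over (text, gold_label) examples."""
    accuracy = {}
    for template in templates:
        correct = 0
        for text, gold in examples:
            result = classifier(text, candidate_labels, hypothesis_template=template)
            if result["labels"][0] == gold:
                correct += 1
        accuracy[template] = correct / len(examples)
    return accuracy


def stand_in_classifier(text, candidate_labels, hypothesis_template="This example is {}."):
    """Offline stand-in: ranks first any label whose name appears in the text."""
    hits = [label for label in candidate_labels if label in text.lower()]
    ordered = hits + [label for label in candidate_labels if label not in hits]
    scores = [1.0 / len(ordered)] * len(ordered)
    return {"sequence": text, "labels": ordered, "scores": scores}


examples = [
    ("Stocks rallied as business confidence improved.", "business"),
    ("The sports league announced a new season.", "sports"),
]
templates = ["This example is {}.", "The topic of this article is {}."]
print(template_accuracy(stand_in_classifier, examples, ["business", "sports", "politics"], templates))
```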

Conclusion

Zero-shot text classification with BERT offers a powerful and flexible way to categorize text without the need for labeled training data. The Hugging Face Transformers library makes this technique easily accessible. While it may not always match the performance of a fully supervised model, it's an invaluable tool for rapid prototyping, exploring new datasets, and handling situations where labeled data is scarce or unavailable. By understanding the principles of NLI and carefully crafting your candidate labels and hypotheses, you can leverage the power of pre-trained language models to perform surprisingly accurate text classification, even in the absence of training examples. ✅🔥
