Advanced Methods in Natural Language Processing - Session 4¶
Text Classification with AG News Corpus¶
This notebook will guide you through different approaches to text classification using the AG News corpus. We will start with a simple baseline model and gradually move towards more complex and sophisticated models.
Table of Contents¶
Part 1: Baseline Pipeline with TF-IDF and Linear Model
- 1.1. Loading and Exploring Data
- 1.2. Feature Extraction with TF-IDF
- 1.3. Training a Linear Model with Cross-Validation
- 1.4. Model Evaluation
Part 2: LSTM Pipeline with One-Hot Encoding
- 2.1. Preprocessing for LSTM
- 2.2. Building a Bidirectional LSTM Model
- 2.3. Training the LSTM Model
- 2.4. Model Evaluation
Part 3: Word Embedding Add-Ons with Word2Vec
- 3.1. Loading Pre-trained Word2Vec Embeddings
- 3.2. Integrating Word2Vec into LSTM Model
- 3.3. Training and Evaluating the Model
Part 4: Model Explainability (LIME / SHAP)
- 4.1. Why Explainability Matters
- 4.2. Applying LIME to the TF-IDF Model
- 4.3. Comparing Explanations for LSTM and Word2Vec Models
Part 0: Metrics Functions to Consider¶
Before diving into the model building and training, it's crucial to establish the metrics we'll use to evaluate our models. In this part, we will define and discuss the different metrics functions that are commonly used in NLP tasks, particularly for text classification:
Accuracy: Measures the proportion of correct predictions among the total number of cases examined. It's a straightforward metric but can be misleading if the classes are imbalanced.
Precision and Recall: Precision measures the proportion of positive identifications that were actually correct, while recall measures the proportion of actual positives that were identified correctly. These metrics are especially important when dealing with imbalanced datasets.
F1 Score: The harmonic mean of precision and recall. It summarizes in a single number how well a classifier balances precision and recall.
Confusion Matrix: A table used to describe the performance of a classification model on a set of test data for which the true values are known. It allows the visualization of the performance of an algorithm.
ROC and AUC: The receiver operating characteristic curve is a graphical plot that illustrates the diagnostic ability of a binary classifier system. The area under the curve (AUC) is a measure of how well the classifier separates the classes; for a multi-class problem like ours it is computed one-vs-rest.
We will implement these metrics functions using libraries such as scikit-learn, and they will be used to assess and compare the performance of our different models throughout this exercise.
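Since AG News has four classes, ROC/AUC is computed one-vs-rest and then averaged. A minimal sketch of that computation (valid_labels are the true labels; valid_probas is a placeholder name for predicted class probabilities from any of the models below):
from sklearn.metrics import roc_auc_score

# valid_labels: integer class labels, shape (n_samples,)
# valid_probas: predicted probabilities, shape (n_samples, 4), e.g. pipeline.predict_proba(valid_data)
auc = roc_auc_score(valid_labels, valid_probas, multi_class='ovr', average='macro')
print(f"One-vs-rest macro-averaged AUC: {auc:.3f}")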
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# ✅ CLASS TEMPLATE: Metric Comparison Across Models
# Create a reusable Metrics class to:
# - Store evaluation results (accuracy, precision, recall, F1)
# - Compare multiple models side-by-side
# - Plot the results using bar charts
# 🔧 Instructions:
# 1. Define a class `Metrics` that holds a dictionary to store metrics for each method.
# 2. Implement a `.run(y_true, y_pred, method_name)` method that:
# - Computes accuracy, precision, recall, and F1-score.
# - Stores them in the dictionary under the given method name.
# 3. Implement a `.plot()` method that:
# - Creates a 2x2 grid of bar plots (one per metric).
# - Displays the comparison of all methods added via `.run()`.
# 🔍 Hint:
# - Use `sklearn.metrics` functions like `accuracy_score`, `precision_score`, etc.
# - Multiply values by 100 to show percentages.
# - Use `plt.subplots()` for subplot creation.
# - Add value annotations above each bar using `ax.text()`.
# 💡 Once implemented, you can use it like this:
# metrics = Metrics()
# metrics.run(y_true, y_pred_baseline, "Baseline")
# metrics.run(y_true, y_pred_lstm, "LSTM")
# metrics.plot()
class Metrics:
    def __init__(self):
        ...  # TODO: create a dictionary to store the metrics of each method

    def run(self, y_true, y_pred, method_name, average='macro'):
        ...  # TODO: compute accuracy, precision, recall and F1, then store them

    def plot(self):
        ...  # TODO: draw a 2x2 grid of bar charts comparing all stored methods
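If you want a working reference while you build your own version, here is one possible implementation sketch of the template above (macro averaging by default, scores stored as percentages, and a 2x2 grid of bar charts; the layout details are our own choices):
class Metrics:
    def __init__(self):
        # method_name -> {"Accuracy": ..., "Precision": ..., "Recall": ..., "F1": ...}
        self.results = {}

    def run(self, y_true, y_pred, method_name, average='macro'):
        # Compute the four metrics (as percentages) and store them under the method name
        self.results[method_name] = {
            "Accuracy": accuracy_score(y_true, y_pred) * 100,
            "Precision": precision_score(y_true, y_pred, average=average) * 100,
            "Recall": recall_score(y_true, y_pred, average=average) * 100,
            "F1": f1_score(y_true, y_pred, average=average) * 100,
        }

    def plot(self):
        # One bar chart per metric, one bar per method added via .run()
        metric_names = ["Accuracy", "Precision", "Recall", "F1"]
        methods = list(self.results.keys())
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        for ax, metric in zip(axes.ravel(), metric_names):
            values = [self.results[m][metric] for m in methods]
            bars = ax.bar(methods, values)
            ax.set_title(metric)
            ax.set_ylim(0, 105)
            # Annotate each bar with its value
            for bar, value in zip(bars, values):
                ax.text(bar.get_x() + bar.get_width() / 2, value + 1,
                        f"{value:.1f}", ha='center')
        plt.tight_layout()
        plt.show()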
Part 1: Baseline Pipeline with TF-IDF and Linear Model¶
In this part, we will create a baseline model for text classification. This involves:
1. Loading and Exploring Data:¶
We will load the AG News corpus and explore its structure before any preprocessing.
from datasets import load_dataset
# Load the 'ag_news' dataset
dataset = load_dataset("ag_news")
# Explore the structure of the dataset
print(dataset)
Let's create stratified samples for the training and validation sets, ensuring that each class is represented in proportion to its frequency. Working with a smaller sample is faster, and we can experiment on the validation set before touching the test set.
from sklearn.model_selection import train_test_split
data = dataset['train']['text']
labels = dataset['train']['label']
test_data = dataset['test']['text']
test_labels = dataset['test']['label']
# Stratified split to create a smaller training and validation set
train_data, valid_data, train_labels, valid_labels = train_test_split(
data, labels, stratify=labels, test_size=0.2, random_state=42
)
# Further split to get 10k and 2k samples respectively
train_data, _, train_labels, _ = train_test_split(
train_data, train_labels, stratify=train_labels, train_size=10000, random_state=42
)
valid_data, _, valid_labels, _ = train_test_split(
valid_data, valid_labels, stratify=valid_labels, train_size=2000, random_state=42
)
from wordcloud import WordCloud
import matplotlib.pyplot as plt
from collections import defaultdict
import nltk
from nltk.corpus import stopwords
# ✅ TASK: Visualize Word Distributions Per Class with Word Clouds
# Goal:
# - Create word clouds for each class label to visually inspect frequent terms.
# - This helps you understand which words dominate each topic (e.g., sports, business, etc.)
# 🧰 Required Libraries:
# - `wordcloud.WordCloud` for generating word clouds
# - `nltk.corpus.stopwords` for filtering out common English words
# - `matplotlib.pyplot` for plotting
# - `collections.defaultdict` for grouping text by label
# 🪜 Instructions:
# 1. Import the required libraries.
# 2. Download and prepare the list of English stopwords using NLTK.
# 3. Create a dictionary `label_data` (e.g., with `defaultdict`) to accumulate all text per class label.
# 4. Set up a 2x2 subplot layout using `plt.subplots` to visualize 4 categories.
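A possible sketch of these steps (the AG News label names and the figure settings are our own choices):
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

# Group all training texts by their class label
label_names = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}
label_data = defaultdict(str)
for text, label in zip(train_data, train_labels):
    label_data[label] += " " + text

# One word cloud per class in a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
for ax, (label, text) in zip(axes.ravel(), sorted(label_data.items())):
    cloud = WordCloud(stopwords=stop_words, background_color='white').generate(text)
    ax.imshow(cloud, interpolation='bilinear')
    ax.set_title(label_names[label])
    ax.axis('off')
plt.tight_layout()
plt.show()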
from collections import Counter
import matplotlib.pyplot as plt
# ✅ TASK: Visualize Class Distribution with a Pie Chart
# Goal:
# - Create a pie chart to show the proportion of each class in the training dataset.
# - This helps you detect class imbalance, which is important for model training and evaluation.
# 🧰 Required Libraries:
# - `collections.Counter` to count how many examples exist per label.
# - `matplotlib.pyplot` to create the pie chart.
# 🪜 Instructions:
# 1. Use `Counter` to count how many times each label appears in `train_labels`.
# 2. Convert label indices to label names using a mapping such as {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}.
# 3. Prepare data for plotting
# 4. Use `plt.pie()` to draw the pie chart.
# 5. Call `plt.axis('equal')` to make sure the pie chart is circular.
# 6. Add a title and display the chart with `plt.show()`.
# 🎯 Output: A pie chart showing the proportion of news categories in your dataset.
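A possible sketch of the pie chart (again mapping AG News label indices to names):
label_names = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}

# Count examples per class and map indices to readable names
counts = Counter(train_labels)
names = [label_names[label] for label in sorted(counts)]
sizes = [counts[label] for label in sorted(counts)]

plt.pie(sizes, labels=names, autopct='%1.1f%%')
plt.axis('equal')  # keep the pie circular
plt.title('Class distribution in the training sample')
plt.show()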
2. Feature Extraction with TF-IDF:¶
We will convert the text data into numerical form using TF-IDF vectorization. We will use scikit-learn's Pipeline class to chain the vectorizer and the classifier into a single estimator.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
# ✅ TASK: Build a Baseline Text Classification Pipeline
# Goal:
# - Create a text classification pipeline that transforms text data into TF-IDF features
# and trains a Logistic Regression model on them.
# 🧰 Required Tools:
# - `TfidfVectorizer` to convert raw text into numerical features.
# - `LogisticRegression` as a simple yet effective classifier for baseline comparison.
# - `Pipeline` to combine preprocessing and model training into a single object.
# 🪜 Instructions:
# 1. Create a `Pipeline`
# 2. Fit the pipeline on your training data
# 3. Make predictions on the validation set using `.predict(valid_data)`.
# 🎯 Output:
# - A trained model that you can use to make predictions and evaluate performance.
# - Predicted labels for the validation set stored in `valid_preds`.
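A minimal sketch of this baseline, shown here so the evaluation cell below has valid_preds to work with (the vectorizer and classifier settings are defaults, to be tuned later):
pipeline = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('clf', LogisticRegression(max_iter=1000)),
])

# Fit on the training sample and predict on the validation set
pipeline.fit(train_data, train_labels)
valid_preds = pipeline.predict(valid_data)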
metrics_val = Metrics()
metrics_val.run(valid_labels, valid_preds, "basic TF-IDF")
3. Training with Cross Validation:¶
We will train a linear classifier (such as Logistic Regression) using the extracted features, the Pipeline module, and cross-validation with GridSearchCV.
from sklearn.model_selection import GridSearchCV
# ✅ TASK: Tune Your TF-IDF + Logistic Regression Pipeline using Grid Search
# Goal:
# - Use `GridSearchCV` to find the best combination of hyperparameters for the pipeline.
# - Improve performance by systematically testing different settings for `TfidfVectorizer`.
# 🧰 Required Tools:
# - `GridSearchCV`: for systematic search over hyperparameter space with cross-validation.
# - `param_grid`: dictionary defining which parameters to tune and their candidate values.
# 🪜 Instructions:
# 1. Define the parameter grid
# 2. Create a `GridSearchCV` object
# 3. Fit the grid search on the training set
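A minimal sketch, assuming the step names from the pipeline sketch above; the parameter grid is only an example. It also records the tuned model's validation scores so the comparison plot below includes them, and grid_search.best_estimator_ is reused later for the LIME explanations:
# Example grid: unigrams vs. unigrams+bigrams, and two vocabulary sizes
param_grid = {
    'tfidf__ngram_range': [(1, 1), (1, 2)],
    'tfidf__max_features': [5000, 20000],
}

grid_search = GridSearchCV(pipeline, param_grid, cv=3, scoring='f1_macro', n_jobs=-1)
grid_search.fit(train_data, train_labels)
print("Best parameters:", grid_search.best_params_)

# Evaluate the tuned pipeline on the validation set and store the scores
valid_preds_tuned = grid_search.best_estimator_.predict(valid_data)
metrics_val.run(valid_labels, valid_preds_tuned, "TF-IDF + GridSearch")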
4. Model Evaluation:¶
We will compare the performance of our models on the validation set using the metrics defined in Part 0.
metrics_val.plot()
Part 2: LSTM Pipeline with One-Hot Encoding¶
In this part, we'll explore a more complex model using LSTM:
1. Preprocessing for LSTM:¶
We'll prepare the text data for LSTM, which involves tokenization and converting words to one-hot encoded vectors.
🔠 Tokenization & Vocabulary¶
We use Keras's Tokenizer to convert raw text into sequences of integers.
- num_words=5000 limits the vocabulary to the 5000 most frequent words.
- The tokenizer is fit only on the training data to avoid data leakage.
from tensorflow.keras.preprocessing.text import Tokenizer
# Define vocabulary size (maximum number of words to keep)
vocab_size = 5000 # 🔧 Hyperparameter: Can be increased or decreased based on your dataset size
# Initialize the tokenizer
tokenizer = Tokenizer(num_words=vocab_size)

# Fit the tokenizer on the training text data only (avoids data leakage)
tokenizer.fit_on_texts(train_data)
🧮 Text to Sequence¶
Now that the tokenizer has learned the vocabulary:
- It transforms each sentence into a list of word indices.
- Words not in the top vocab_size are ignored.

Example: "the cat sat" → [1, 45, 213]
# Convert training and validation texts to sequences
sequences_train = tokenizer.texts_to_sequences(train_data)
sequences_valid = tokenizer.texts_to_sequences(valid_data)
📐 Padding Sequences¶
Neural networks require fixed-length input, so we:
- Pad shorter sequences with zeros.
- Truncate longer ones to max_length.

padding='post' adds padding after the sequence (e.g., [45, 213, 0, 0, 0]).
from tensorflow.keras.preprocessing.sequence import pad_sequences
# Set maximum input length
max_length = 128
# Pad or truncate the sequences to fixed length
padded_sequences_train = pad_sequences(sequences_train, maxlen=max_length, padding='post')
padded_sequences_valid = pad_sequences(sequences_valid, maxlen=max_length, padding='post')
🎯 One-Hot Encode Labels¶
For multi-class classification with neural networks, we need to:
- Convert integer class labels to one-hot encoded vectors.
- Class 2 in a 4-class problem → [0, 0, 1, 0]
- This format is required by the softmax output layer in our model.
from tensorflow.keras.utils import to_categorical
# Determine number of unique output classes
num_classes = len(set(train_labels))
# One-hot encode the labels
train_labels_lstm = to_categorical(train_labels, num_classes=num_classes)
valid_labels_lstm = to_categorical(valid_labels, num_classes=num_classes)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Bidirectional, Dense, Input
🧩 Model Layers¶
- Embedding: Transforms integer word indices into dense vectors.
- Bidirectional LSTM: Processes the sequence both forward and backward for richer context.
- Dense: Final layer with softmax activation for multi-class classification.
Each sentence becomes a sequence of vectors, processed to output a probability over each class.
model = Sequential([
    Input(shape=(max_length,)),   # fixed-length padded sequences
    # 🔧 Embedding dimension and LSTM units below are reasonable defaults; adjust as needed
    Embedding(input_dim=vocab_size, output_dim=128),
    Bidirectional(LSTM(64)),
    Dense(num_classes, activation='softmax')
])
⚙️ Model Compilation¶
We compile the model with:
- adam optimizer (adaptive learning rate)
- categorical_crossentropy for multi-class output
- Additional metrics: accuracy, precision, and recall for deeper evaluation
from tensorflow.keras.metrics import Precision, Recall
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy', Precision(), Recall()]
)
🧾 Model Summary¶
Let’s take a quick look at the number of parameters, layer types, and output shapes in our model.
✅ This helps validate that the model is structured as intended.
model.summary()
3. Training the LSTM Model:¶
We'll train our LSTM model on the preprocessed text data.
# 🔧 epochs and batch_size below are reasonable defaults; adjust as needed
history = model.fit(padded_sequences_train, train_labels_lstm, epochs=10, batch_size=128,
                    validation_data=(padded_sequences_valid, valid_labels_lstm))
4. Model Evaluation:¶
Similar to Part 1, we will evaluate our model's performance using appropriate metrics.
predictions = model.predict(padded_sequences_valid)
valid_preds = np.argmax(predictions, axis=1)
metrics_val.run(valid_labels, valid_preds, "BiLSTM")
metrics_val.plot()
Part 3: Word Embedding Add-Ons with Word2Vec¶
This part focuses on integrating pre-trained word embeddings into our model.
1. Loading Pre-trained Word2Vec Embeddings:¶
We'll load Word2Vec embeddings pre-trained on a large corpus.
from staticvectors import StaticVectors
word2vec_model = StaticVectors("neuml/word2vec")
2. Integrating Word2Vec into LSTM Model:¶
We'll use these embeddings as inputs to our LSTM model, potentially enhancing its ability to understand context and semantics.
📦 Preparing the Word2Vec Embedding Matrix¶
We will now create an embedding matrix that maps each word in our tokenizer vocabulary to its pre-trained Word2Vec vector.
- We initialize a matrix of zeros with shape (vocab_size, 300)
- 300 is the dimensionality of Google's pre-trained Word2Vec vectors
🧮 Embedding Matrix Initialization¶
Every row i in this matrix will store the 300-dimensional Word2Vec vector for the i-th word in our vocabulary.
embedding_matrix = np.zeros((vocab_size, 300))
🔍 Filling the Embedding Matrix¶
For each word in our tokenizer's vocabulary:
- We look up its corresponding vector from the pre-trained Word2Vec model
- If it's not found (OOV word), we leave the row as zeros
- We use tqdm to track progress, which is helpful when processing large vocabularies
from tqdm import tqdm  # progress bar (imported here because this cell runs before the later import)

# ⚠️ Assumption: the lookup below uses StaticVectors' embeddings([word]) call, returning one
# vector per input word; adapt it to your Word2Vec loader if its API differs.
for word, i in tqdm(tokenizer.word_index.items()):
    if i < vocab_size:
        try:
            embedding_matrix[i] = word2vec_model.embeddings([word])[0]
        except Exception:
            pass  # OOV word: leave the row as zeros
🧱 Model with Word2Vec Embeddings¶
- The Embedding layer now uses our pre-trained Word2Vec weights
- trainable=False means we freeze the embeddings (no updates during training)
- Same Bidirectional LSTM and Dense classifier as before
⚙️ Compile the Model¶
We compile with:
- adam optimizer
- categorical_crossentropy loss for multi-class classification
- Accuracy as our evaluation metric
Then we print the model architecture to review.
import numpy as np
from tqdm import tqdm
from tensorflow.keras.initializers import Constant

model = Sequential([
    Input(shape=(max_length,)),
    # Frozen 300-dimensional pre-trained Word2Vec embeddings
    Embedding(input_dim=vocab_size, output_dim=300,
              embeddings_initializer=Constant(embedding_matrix), trainable=False),
    Bidirectional(LSTM(64)),   # 🔧 LSTM units: a reasonable default, adjust as needed
    Dense(num_classes, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
3. Training and Evaluating the Model:¶
We'll train our model with these new embeddings and evaluate to see if there's an improvement in performance.
from tensorflow.keras.callbacks import EarlyStopping
# Setup early stopping to stop training when validation loss stops improving
early_stopping = EarlyStopping(
    monitor='val_loss',          # watch validation loss
    patience=2,                  # 🔧 a reasonable default; adjust as needed
    restore_best_weights=True    # roll back to the weights of the best epoch
)
history = model.fit(
padded_sequences_train,
train_labels_lstm,
epochs=10,
batch_size=128,
validation_data=(padded_sequences_valid, valid_labels_lstm),
callbacks=[early_stopping]
)
4. Model Evaluation:¶
Similar to previous parts, we will evaluate our model's performance using appropriate metrics.
predictions = model.predict(padded_sequences_valid)
valid_preds = np.argmax(predictions, axis=1)
metrics_val.run(valid_labels, valid_preds, "BiLSTM + W2V")
metrics_val.plot()
Part 4: Model Explainability (LIME / SHAP)¶
As machine learning models become more powerful, they also become more complex and opaque — especially deep learning models like LSTMs or transformer-based architectures. This complexity makes it harder to understand why a model makes a specific prediction.
In this part, we will explore model explainability, a crucial step in building trustworthy and transparent machine learning systems.
4.1 Why Explainability Matters¶
Imagine you’ve built a model that predicts whether a customer is likely to churn, or whether a loan should be approved. Even if your model is 90% accurate, you might be asked:
🤔 "But why did the model make that decision?"
This is where explainability becomes essential.
Why it's important:¶
- Trust: Users are more likely to trust predictions they can understand.
- Debugging: Helps identify spurious correlations or biases in the model.
- Fairness & Ethics: Ensures decisions are not based on sensitive or discriminatory attributes.
- Regulatory Compliance: In some domains (like finance or healthcare), explainability is required by law.
Two Main Categories of Explainability:¶
- Global Explanations: Understanding the overall model behavior
  Example: Which words generally influence sentiment predictions the most?
- Local Explanations: Understanding individual predictions
  Example: Why was *this* review classified as negative?
We'll focus primarily on local explanations using:
- LIME (for simpler models like TF-IDF + Logistic Regression)
Let’s begin by applying LIME to our TF-IDF baseline model.
4.2 🧪 Applying LIME to the TF-IDF Model¶
Now that we understand why explainability matters, let's start by applying LIME (Local Interpretable Model-agnostic Explanations) to our first model:
➡️ A TF-IDF + Logistic Regression pipeline.
LIME works by slightly perturbing the input text and seeing how the model prediction changes.
From this, it builds a local, interpretable surrogate model (like a linear regression) to approximate the complex model's behavior near that input.
We'll explain:
- A single prediction for a text sample
- Which words had the most impact (positive or negative) on the predicted label
from lime.lime_text import LimeTextExplainer
import shap
# LIME explainer function for the TF-IDF pipeline
def explain_tfidf_prediction(text_instance, pipeline, class_names):
# Create a LIME text explainer
explainer = LimeTextExplainer(class_names=class_names)
# Get explanation for the prediction
exp = explainer.explain_instance(
text_instance,
pipeline.predict_proba,
num_features=10,
top_labels=len(class_names) # Explain all classes
)
# Display basic information
print(f"Text: {text_instance}")
pred_class = pipeline.predict([text_instance])[0]
print(f"Predicted class: {class_names[pred_class]}")
# Get probabilities for all classes
probs = pipeline.predict_proba([text_instance])[0]
print("\nClass probabilities:")
for i, class_name in enumerate(class_names):
print(f"{class_name}: {probs[i]:.4f}")
# Create visualization for each class
plt.figure(figsize=(20, 15))
# Get the labels that LIME actually explained
top_labels = exp.available_labels()
for i, label_id in enumerate(top_labels):
plt.subplot(2, 2, i+1)
# Get the explanation for this class
exp_list = exp.as_list(label=label_id)
# Extract words and weights
words = [x[0] for x in exp_list]
weights = [x[1] for x in exp_list]
# Sort for better visualization
pairs = sorted(zip(words, weights), key=lambda x: x[1])
words = [x[0] for x in pairs]
weights = [x[1] for x in pairs]
# Create bar chart
colors = ['red' if w < 0 else 'green' for w in weights]
y_pos = np.arange(len(words))
plt.barh(y_pos, weights, color=colors)
plt.yticks(y_pos, words)
plt.title(f"Explanation for class: {class_names[label_id]}")
plt.axvline(x=0, color='black', linestyle='-', alpha=0.5)
plt.tight_layout()
plt.show()
# Print top contributing words for each class
for label_id in top_labels:
print(f"\nTop features for class: {class_names[label_id]}")
exp_list = exp.as_list(label=label_id)
for word, weight in exp_list:
print(f"{word}: {weight:.4f}")
return exp
# Example text from test set
example_text = test_data[0]
class_names = ['World', 'Sports', 'Business', 'Sci/Tech']
lime_exp_tfidf = explain_tfidf_prediction(example_text, grid_search.best_estimator_, class_names)
🧠 Interpretation of LIME Explanation¶
Let's break down what LIME revealed about the model's reasoning for this particular prediction.
✅ Predicted Class: Business¶
LIME shows us the top 10 words that contributed positively or negatively to each possible class (Business, World, Sci/Tech, Sports).
💬 For Class: Business¶
- Words like "Federal", "firm", and "pension" have positive weights, meaning they support the Business prediction.
- The word "talks" actually detracts from the Business prediction (negative weight), suggesting it's a bit ambiguous.
🌍 For Class: World¶
- Interestingly, "talks" strongly supports the World class here.
- Other Business-related terms (e.g., "Federal", "firm") detract from a World prediction.
💡 Insights:¶
- Words like "workers", "Unions", and "say" appear across multiple classes with small influence, showing they’re more generic.
- "talks" is context-dependent – LIME helps us disentangle how the same word can shift meaning depending on the rest of the sentence.
🧭 Takeaway: LIME helps us peek inside the model's black box and see which features are driving predictions. It also shows that certain words may support multiple classes, but with different intensities.
Ready to move on? Let’s now explore:
4.3 📊 Comparing Explanations for LSTM and Word2Vec Models¶
➡️ Here, we'll try to interpret more complex models (like LSTM and Word2Vec-based models) using LIME and compare how their reasoning differs from the simpler TF-IDF model.
def prepare_text_for_lstm(text, tokenizer, max_length):
"""Prepare text input for LSTM model"""
from tensorflow.keras.preprocessing.sequence import pad_sequences
sequences = tokenizer.texts_to_sequences([text])
    padded_seq = pad_sequences(sequences, maxlen=max_length, padding='post')  # match the 'post' padding used in training
return padded_seq
def lstm_predict_proba(texts):
"""Prediction function for LIME to use with LSTM model"""
result = np.zeros((len(texts), len(class_names)))
for i, text in enumerate(texts):
padded = prepare_text_for_lstm(text, tokenizer, max_length)
pred = model.predict(padded, verbose=0)
result[i] = pred[0]
return result
def explain_lstm_prediction(text_instance, class_names):
# Create a LIME text explainer
explainer = LimeTextExplainer(class_names=class_names)
# Get explanation for the prediction
exp = explainer.explain_instance(
text_instance,
lstm_predict_proba,
num_features=10,
top_labels=len(class_names) # Explain all classes
)
# Display basic information
print(f"Text: {text_instance}")
padded = prepare_text_for_lstm(text_instance, tokenizer, max_length)
prediction = model.predict(padded, verbose=0)
predicted_class = np.argmax(prediction[0])
print(f"Predicted class: {class_names[predicted_class]}")
# Get probabilities for all classes
probs = prediction[0]
print("\nClass probabilities:")
for i, class_name in enumerate(class_names):
print(f"{class_name}: {probs[i]:.4f}")
# Create visualization for each class
plt.figure(figsize=(20, 15))
# Get the labels that LIME actually explained
top_labels = exp.available_labels()
for i, label_id in enumerate(top_labels):
plt.subplot(2, 2, i+1)
# Get the explanation for this class
exp_list = exp.as_list(label=label_id)
# Extract words and weights
words = [x[0] for x in exp_list]
weights = [x[1] for x in exp_list]
# Sort for better visualization
pairs = sorted(zip(words, weights), key=lambda x: x[1])
words = [x[0] for x in pairs]
weights = [x[1] for x in pairs]
# Create bar chart
colors = ['red' if w < 0 else 'green' for w in weights]
y_pos = np.arange(len(words))
plt.barh(y_pos, weights, color=colors)
plt.yticks(y_pos, words)
plt.title(f"Explanation for class: {class_names[label_id]}")
plt.axvline(x=0, color='black', linestyle='-', alpha=0.5)
plt.tight_layout()
plt.show()
# Print top contributing words for each class
for label_id in top_labels:
print(f"\nTop features for class: {class_names[label_id]}")
exp_list = exp.as_list(label=label_id)
for word, weight in exp_list:
print(f"{word}: {weight:.4f}")
return exp
example_text = test_data[0]
model_exp_lime = explain_lstm_prediction(example_text, class_names)
🔍 LIME Explanation for LSTM Model¶
🧠 Sentence Analyzed:¶
"Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."
✅ Predicted Class: Sci/Tech¶
With a very high probability (78.35%), the LSTM model predicted this text belongs to Sci/Tech.
📊 Class Probabilities:¶
- Sci/Tech: 0.7835
- World: 0.0995
- Business: 0.0839
- Sports: 0.0331
💡 What LIME Reveals¶
Top Features Supporting Sci/Tech:¶
- Common connectors like "they", "after", "are", "with", and "at" surprisingly contribute positively to Sci/Tech.
- Words like "firm" and "talks", which might intuitively relate to Business or World, actually reduce the Sci/Tech probability here.
Across Other Classes:¶
- Most of the same function words ("they", "after", "with") appear as negative contributors to World, Business, and Sports classes.
- Words like "firm" slightly boost the Business class but are downplayed for Sci/Tech.
- In the Sports class, all top words negatively influence the prediction, confirming it’s a poor match.
🧠 Interpretation:¶
Unlike the TF-IDF model, which focused on specific content words, the LSTM model seems to be relying heavily on syntactic or structural features (like function words and word order). This could be due to:
- The sequential nature of LSTM, which captures contextual dependencies
- A possible lack of domain-specific keywords driving this prediction