
Advanced Methods in Natural Language Processing - Session 4¶
Text Classification with AG News Corpus¶
This notebook will guide you through different approaches to text classification using the AG News corpus. We will start with a simple baseline model and gradually move towards more complex and sophisticated models.
Table of Contents¶
Part 1: Baseline Pipeline with TF-IDF and Linear Model
- 1.1. Loading and Exploring Data
- 1.2. Feature Extraction with TF-IDF
- 1.3. Training a Linear Model
- 1.4. Model Evaluation
Part 2: LSTM Pipeline with One-Hot Encoding
- 2.1. Preprocessing for LSTM
- 2.2. Building a Bidirectional LSTM Model
- 2.3. Training the LSTM Model
- 2.4. Model Evaluation
Part 3: Word Embedding Add-Ons with Word2Vec
- 3.1. Loading Pre-trained Word2Vec Embeddings
- 3.2. Integrating Word2Vec into LSTM Model
- 3.3. Training and Evaluating the Model
Part 4: Model Explainability (LIME / SHAP)
- 4.1. Why Explainability Matters
- 4.2. Applying LIME to the TF-IDF Model
- 4.3. Comparing Explanation for LSTM with Word2Vec model
Part 0: Metrics Functions to Consider¶
Before diving into the model building and training, it's crucial to establish the metrics we'll use to evaluate our models. In this part, we will define and discuss the different metrics functions that are commonly used in NLP tasks, particularly for text classification:
Accuracy: Measures the proportion of correct predictions among the total number of cases examined. It's a straightforward metric but can be misleading if the classes are imbalanced.
Precision and Recall: Precision measures the proportion of positive identifications that were actually correct, while recall measures the proportion of actual positives that were identified correctly. These metrics are especially important when dealing with imbalanced datasets.
F1 Score: The harmonic mean of precision and recall. A high F1 score indicates that a classifier balances precision and recall well.
Confusion Matrix: A table used to describe the performance of a classification model on a set of test data for which the true values are known. It allows the visualization of the performance of an algorithm.
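ROC and AUC: The receiver operating characteristic (ROC) curve plots the true positive rate against the false positive rate of a binary classifier as the decision threshold varies. The area under the curve (AUC) is a measure of separability: 1.0 means perfect separation and 0.5 is chance level. For a multi-class problem such as AG News, it is typically computed one-vs-rest for each class.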
We will implement these metrics functions using libraries such as scikit-learn, and they will be used to assess and compare the performance of our different models throughout this exercise.
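To make the `average='macro'` option used in this notebook concrete, here is a minimal pure-Python sketch of macro-averaged precision, recall and F1 on a toy three-class example. The labels are made up for illustration and `macro_scores` is a hypothetical helper; in the notebook itself scikit-learn does this work.

```python
from collections import Counter

def macro_scores(y_true, y_pred):
    """Macro-averaged precision, recall and F1 (hypothetical helper)."""
    classes = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # p was predicted, but the true class was different
            fn[t] += 1  # a true t was missed
    precisions, recalls, f1s = [], [], []
    for c in classes:
        prec = tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
        rec = tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    n = len(classes)
    # Macro average: every class counts equally, whatever its size
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n

# Toy labels (illustrative only)
y_true = [0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 1, 1, 1, 2, 0]

precision, recall, f1 = macro_scores(y_true, y_pred)
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"accuracy={accuracy:.4f} precision={precision:.4f} "
      f"recall={recall:.4f} f1={f1:.4f}")
```

Because the macro average weights every class equally, a small class that is classified poorly pulls the score down as much as a large one would, which is why it is a reasonable default when class balance is not guaranteed.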
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
class Metrics:
    def __init__(self):
        self.results = {}

    def run(self, y_true, y_pred, method_name, average='macro'):
        # Calculate metrics
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average=average)
        recall = recall_score(y_true, y_pred, average=average)
        f1 = f1_score(y_true, y_pred, average=average)
        # Store results
        self.results[method_name] = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
        }

    def plot(self):
        # Create subplots
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))
        # Plot each metric
        for i, metric in enumerate(['accuracy', 'precision', 'recall', 'f1']):
            ax = axs[i//2, i%2]
            values = [res[metric] * 100 for res in self.results.values()]
            ax.bar(self.results.keys(), values)
            ax.set_title(metric)
            ax.set_ylim(0, 100)
            # Add values on the bars
            for j, v in enumerate(values):
                ax.text(j, v + 0.02, f"{v:.2f}", ha='center', va='bottom')
        plt.tight_layout()
        plt.show()
Part 1: Baseline Pipeline with TF-IDF and Linear Model¶
In this part, we will create a baseline model for text classification. This involves:
1. Loading and Exploring Data:¶
We will load the AG News corpus and explore its structure before doing any preprocessing.
from datasets import load_dataset
# Load the 'ag_news' dataset
dataset = load_dataset("ag_news")
# Explore the structure of the dataset
print(dataset)
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 120000
})
test: Dataset({
features: ['text', 'label'],
num_rows: 7600
})
})
Let's create stratified samples for the training and validation sets, so that each class is represented in proportion to its frequency. Working with a sample (10,000 training and 2,000 validation examples) is faster, and it lets us test on the validation set before touching the test set.
from sklearn.model_selection import train_test_split
data = dataset['train']['text']
labels = dataset['train']['label']
test_data = dataset['test']['text']
test_labels = dataset['test']['label']
# Stratified split to create a smaller training and validation set
train_data, valid_data, train_labels, valid_labels = train_test_split(
    data, labels, stratify=labels, test_size=0.2, random_state=42
)
# Further split to get 10k and 2k samples respectively
train_data, _, train_labels, _ = train_test_split(
    train_data, train_labels, stratify=train_labels, train_size=10000, random_state=42
)
valid_data, _, valid_labels, _ = train_test_split(
    valid_data, valid_labels, stratify=valid_labels, train_size=2000, random_state=42
)
from wordcloud import WordCloud
import matplotlib.pyplot as plt
from collections import defaultdict
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
labels = {0: 'World', 1: 'Sports',
          2: 'Business', 3: 'Sci/Tech'}
# Prepare data for wordclouds
label_data = defaultdict(lambda: '')
for text, label in zip(train_data, train_labels):
    # Add a separator so the last word of one article does not fuse with the first word of the next
    label_data[label] += text + ' '
# Generate and plot wordclouds for each label
fig, axs = plt.subplots(2, 2, figsize=(10, 6))  # Create 2x2 subplots
axs = axs.flatten()  # Flatten the axis array
for ax, (label, text) in zip(axs, label_data.items()):
    wordcloud = WordCloud(stopwords=stop_words, background_color='white').generate(text)
    ax.imshow(wordcloud, interpolation='bilinear')
    ax.set_title('WordCloud for Label {}'.format(labels.get(label)))
    ax.axis('off')
plt.tight_layout()
plt.show()
from collections import Counter
import matplotlib.pyplot as plt
# Count the frequency of each label
label_counts = Counter(train_labels)
# Data to plot
_labels = [labels.get(lab) for lab in label_counts.keys()]
sizes = label_counts.values()
colors = ['gold', 'yellowgreen', 'lightcoral', 'lightskyblue']
# Plotting the pie chart
plt.pie(sizes, labels=_labels, colors=colors, autopct='%1.1f%%', startangle=140)
plt.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle.
plt.title('Proportion of Each Label')
plt.show()
2. Feature Extraction with TF-IDF:¶
We will convert the text data into numerical form using TF-IDF vectorization. scikit-learn's Pipeline class lets us chain the vectorizer and the classifier into a single estimator, so exactly the same preprocessing is applied at fit time and at predict time.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
# Create a pipeline with TF-IDF and Logistic Regression
pipeline = Pipeline([
    ('tfidf', TfidfVectorizer(ngram_range=(1, 2),
                              min_df=5,
                              stop_words='english')),
    ('clf', LogisticRegression(solver='lbfgs')),
])
# Fit the pipeline on the training data
pipeline.fit(train_data, train_labels)
valid_preds = pipeline.predict(valid_data)
metrics_val = Metrics()
metrics_val.run(valid_labels, valid_preds, "basic TF-IDF")
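To show what the TF-IDF step computes, here is a small pure-Python sketch on a made-up corpus. It assumes scikit-learn's default weighting (smoothed idf, `ln((1 + n) / (1 + df)) + 1`, followed by L2 normalization) and uses naive whitespace tokenization; it does not reproduce `TfidfVectorizer`'s tokenizer, stop-word removal, n-grams or `min_df`. The helper name `tfidf` is hypothetical.

```python
import math
from collections import Counter

def tfidf(docs):
    """Toy TF-IDF with smoothed idf and L2-normalized rows (illustrative helper)."""
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    # Document frequency: in how many documents each term appears
    df = Counter(term for tokens in tokenized for term in set(tokens))
    idf = {term: math.log((1 + n_docs) / (1 + count)) + 1
           for term, count in df.items()}
    vectors = []
    for tokens in tokenized:
        tf = Counter(tokens)
        weights = {term: freq * idf[term] for term, freq in tf.items()}
        norm = math.sqrt(sum(w * w for w in weights.values()))
        vectors.append({term: w / norm for term, w in weights.items()})
    return vectors, idf

# Made-up mini corpus
docs = [
    "stocks fall as oil prices rise",
    "oil prices hit record high",
    "team wins the cup final",
]
vectors, idf = tfidf(docs)

# "oil" appears in two documents, "stocks" in one, so "stocks" gets the larger weight
print(f"idf[oil]={idf['oil']:.4f}  idf[stocks]={idf['stocks']:.4f}")
print({term: round(w, 3) for term, w in vectors[0].items()})
```

Terms shared by many documents get a low idf and contribute little, while terms concentrated in a few documents stand out; that is the signal the logistic regression then learns from.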
3. Training with Cross Validation:¶
We will tune the TF-IDF hyperparameters of the pipeline (TF-IDF features followed by Logistic Regression) using 5-fold cross-validation with GridSearchCV.
from sklearn.model_selection import GridSearchCV
# Define the parameter grid
param_grid = {
    'tfidf__min_df': [1, 2, 5],  # Example values, you can choose others
    # Unigrams only, uni+bigrams, uni+bi+trigrams, bigrams only
    'tfidf__ngram_range': [(1, 1), (1, 2), (1, 3), (2, 2)]
}
# Create a GridSearchCV object
grid_search = GridSearchCV(pipeline, param_grid, cv=5, n_jobs=-1, verbose=1)
# Fit the grid search to the data
grid_search.fit(train_data, train_labels)
# Best parameters found by grid search
print(f'Best Parameters: {grid_search.best_params_}')
# Evaluate on the validation set
valid_preds = grid_search.predict(valid_data)
metrics_val.run(valid_labels, valid_preds, "CV-ed TF-IDF")
Fitting 5 folds for each of 12 candidates, totalling 60 fits
Best Parameters: {'tfidf__min_df': 2, 'tfidf__ngram_range': (1, 2)}
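The "12 candidates, totalling 60 fits" line follows directly from the grid: every combination of the listed values is tried once per fold. A minimal sketch of that enumeration, using only the standard library (the variable names `candidates` and `n_fits` are illustrative):

```python
from itertools import product

param_grid = {
    'tfidf__min_df': [1, 2, 5],
    'tfidf__ngram_range': [(1, 1), (1, 2), (1, 3), (2, 2)],
}
cv_folds = 5

# Cartesian product of all parameter values: one dict per candidate
candidates = [dict(zip(param_grid, values))
              for values in product(*param_grid.values())]
n_fits = len(candidates) * cv_folds

print(len(candidates), "candidates,", n_fits, "cross-validation fits")
print(candidates[0])
```

With the default `refit=True`, GridSearchCV additionally refits the best candidate once on the full training data, which is why `grid_search.predict` can be called directly afterwards.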
4. Model Evaluation:¶
We will evaluate the performance of our models on the held-out validation set, using the metrics defined in Part 0. The test set stays untouched for now.
metrics_val.plot()
Part 2: LSTM Pipeline with One-Hot Encoding¶
In this part, we'll explore a more complex model using LSTM:
1. Preprocessing for LSTM:¶
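We'll prepare the text data for the LSTM: tokenizing the texts, mapping each word to an integer index, and padding the sequences to a fixed length. The Embedding layer consumes these indices directly, which is equivalent to multiplying a one-hot vector by the embedding matrix. Only the labels are explicitly one-hot encoded.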
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
# Parameters
vocab_size = 5000 # This is a hyperparameter, adjust as needed
max_length = 128 # This is another hyperparameter
# Initialize and fit the tokenizer
tokenizer = Tokenizer(num_words=vocab_size)
tokenizer.fit_on_texts(train_data)
# Convert texts to sequences of integers
sequences_train = tokenizer.texts_to_sequences(train_data)
sequences_valid = tokenizer.texts_to_sequences(valid_data)
# Pad sequences to the same length
padded_sequences_train = pad_sequences(sequences_train, maxlen=max_length,
                                       padding='post', truncating='post')
padded_sequences_valid = pad_sequences(sequences_valid, maxlen=max_length,
                                       padding='post', truncating='post')
# Assuming train_labels are integer labels
num_classes = len(set(train_labels)) # Determine the number of unique classes
# Convert labels to one-hot vectors
train_labels_lstm = to_categorical(train_labels, num_classes=num_classes)
valid_labels_lstm = to_categorical(valid_labels, num_classes=num_classes)
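As a rough illustration of what the padding and label-encoding steps above produce, here is a pure-Python sketch. `pad_post` and `one_hot` are hypothetical helpers that mimic, on plain lists, the effect of `pad_sequences(..., padding='post', truncating='post')` and `to_categorical`; the tiny `word_index` is made up.

```python
def pad_post(sequence, max_length, pad_value=0):
    """Cut off the end if too long, then pad at the end up to max_length."""
    truncated = sequence[:max_length]
    return truncated + [pad_value] * (max_length - len(truncated))

def one_hot(label, num_classes):
    """Turn an integer label into a one-hot vector."""
    vector = [0.0] * num_classes
    vector[label] = 1.0
    return vector

# Made-up vocabulary; index 0 is reserved for padding
word_index = {"oil": 1, "prices": 2, "rise": 3, "again": 4}
tokens = "oil prices rise sharply".split()
# Words outside the vocabulary ("sharply") are simply dropped
sequence = [word_index[t] for t in tokens if t in word_index]

print(pad_post(sequence, 5))            # short sequence: zeros appended
print(pad_post([1, 2, 3, 4, 1, 2], 5))  # long sequence: tail cut off
print(one_hot(2, 4))                    # label 2 (Business) out of 4 classes
```

Whatever padding side is chosen, the same choice has to be used for every input the model sees later, including at explanation time.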
2. Building a Bidirectional LSTM Model:¶
We'll design a neural network with a Bidirectional LSTM layer to capture context from both directions in the text.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Embedding, LSTM, Bidirectional, Dense
from tensorflow.keras.metrics import Precision, Recall
model = Sequential([
    Input(shape=(max_length,)),  # max_length = fixed sequence length (128 tokens after padding)
    Embedding(vocab_size, output_dim=64),  # vocab_size = 5000 words; output_dim=64 = size of the dense vector per word (hyperparameter, larger = more expressive but more params)
    Bidirectional(LSTM(64)),  # 64 = number of LSTM hidden units (size of the internal state); Bidirectional doubles the output -> 128 features
    Dense(num_classes, activation='softmax')  # num_classes = 4 (World, Sports, Business, Sci/Tech); softmax -> probability distribution over classes
])
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy', Precision(), Recall()])
model.summary()
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ embedding_1 (Embedding)         │ (None, 128, 64)        │       320,000 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ bidirectional_1 (Bidirectional) │ (None, 128)            │        66,048 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 4)              │           516 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 386,564 (1.47 MB)
Trainable params: 386,564 (1.47 MB)
Non-trainable params: 0 (0.00 B)
🔢 Where Do the Parameters Come From?¶
Total: 386,564 trainable parameters. Let's break it down layer by layer.
1. Embedding → 320,000 params¶
Formula: vocab_size × output_dim
$$ 5000 \times 64 = 320{,}000 $$
Each of the 5000 vocabulary words gets its own learnable 64-dimensional vector. No bias here — just a lookup table.
2. Bidirectional(LSTM(64)) → 66,048 params¶
A single LSTM cell has 4 internal gates (input, forget, cell, output). Each gate has its own weight matrix on the input, its own recurrent weight matrix on the hidden state, and a bias.
Formula for one LSTM direction:
$$ 4 \times \big( (\text{input\_dim} \times \text{units}) + (\text{units} \times \text{units}) + \text{units} \big) $$
With input_dim = 64 (embedding output) and units = 64:
$$ 4 \times (64 \times 64 + 64 \times 64 + 64) = 4 \times (4096 + 4096+ 64) = 4 \times 8256 = 33{,}024 $$
Bidirectional runs two LSTMs (forward + backward), so:
$$ 2 \times 33{,}024 = 66{,}048 $$
The output shape becomes (None, 128) because the two 64-dim hidden states are concatenated.
3. Dense(4, softmax) → 516 params¶
Formula: (input_dim × units) + units (weights + bias)
$$ (128 \times 4) + 4 = 512 + 4 = 516 $$
The input is 128 (the BiLSTM output) and the output is 4 (one score per class, which softmax turns into a probability).
💡 Key Takeaways¶
- The embedding layer dominates the parameter count (~83% of the total). Increasing vocab_size or output_dim blows up the model size fast.
- The recurrent part of the LSTM cost scales as 4 × units². Doubling units quadruples that term; with input_dim fixed at 64, the LSTM total roughly triples (33,024 → 98,816 per direction).
- The Dense layer is tiny because the number of classes is small.
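The breakdown above can be checked with a few lines of arithmetic. This is only a sketch of the counting formulas from this section; the three helper functions are hypothetical and do not touch Keras at all.

```python
def embedding_params(vocab_size, output_dim):
    # One learnable vector per vocabulary entry, no bias
    return vocab_size * output_dim

def lstm_params(input_dim, units):
    # 4 gates, each with input weights, recurrent weights and a bias
    return 4 * (input_dim * units + units * units + units)

def dense_params(input_dim, units):
    # Weights plus bias
    return input_dim * units + units

vocab_size, embed_dim, units, num_classes = 5000, 64, 64, 4

emb = embedding_params(vocab_size, embed_dim)
bilstm = 2 * lstm_params(embed_dim, units)    # forward + backward
dense = dense_params(2 * units, num_classes)  # BiLSTM output is 2 * units wide
total = emb + bilstm + dense

print(f"embedding={emb:,}  bilstm={bilstm:,}  dense={dense:,}  total={total:,}")
print(f"embedding share: {100 * emb / total:.1f}%")
print(f"one direction with units=128: {lstm_params(embed_dim, 128):,}")
```

The totals match the `model.summary()` output, and the last line shows the effect of doubling `units` while the embedding size stays at 64.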
3. Training the LSTM Model:¶
We'll train our LSTM model on the preprocessed text data.
history = model.fit(
    padded_sequences_train,
    train_labels_lstm,
    epochs=10,
    batch_size=128,
    validation_data=(padded_sequences_valid, valid_labels_lstm)
)
Epoch 1/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - accuracy: 0.5454 - loss: 1.1141 - precision_1: 0.8685 - recall_1: 0.1492 - val_accuracy: 0.7910 - val_loss: 0.6697 - val_precision_1: 0.9185 - val_recall_1: 0.4900
Epoch 2/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 9s 118ms/step - accuracy: 0.8509 - loss: 0.4826 - precision_1: 0.9034 - recall_1: 0.7490 - val_accuracy: 0.8670 - val_loss: 0.3931 - val_precision_1: 0.8972 - val_recall_1: 0.8380
Epoch 3/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 138ms/step - accuracy: 0.9129 - loss: 0.2736 - precision_1: 0.9263 - recall_1: 0.8973 - val_accuracy: 0.8825 - val_loss: 0.3497 - val_precision_1: 0.8990 - val_recall_1: 0.8635
Epoch 4/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 140ms/step - accuracy: 0.9426 - loss: 0.1873 - precision_1: 0.9506 - recall_1: 0.9357 - val_accuracy: 0.8765 - val_loss: 0.3649 - val_precision_1: 0.8885 - val_recall_1: 0.8645
Epoch 5/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 136ms/step - accuracy: 0.9648 - loss: 0.1311 - precision_1: 0.9684 - recall_1: 0.9601 - val_accuracy: 0.8730 - val_loss: 0.3822 - val_precision_1: 0.8826 - val_recall_1: 0.8685
Epoch 6/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 134ms/step - accuracy: 0.9748 - loss: 0.0978 - precision_1: 0.9776 - recall_1: 0.9726 - val_accuracy: 0.8725 - val_loss: 0.4049 - val_precision_1: 0.8803 - val_recall_1: 0.8640
Epoch 7/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 140ms/step - accuracy: 0.9825 - loss: 0.0757 - precision_1: 0.9838 - recall_1: 0.9813 - val_accuracy: 0.8575 - val_loss: 0.4647 - val_precision_1: 0.8642 - val_recall_1: 0.8525
Epoch 8/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 141ms/step - accuracy: 0.9869 - loss: 0.0571 - precision_1: 0.9877 - recall_1: 0.9861 - val_accuracy: 0.8650 - val_loss: 0.4758 - val_precision_1: 0.8727 - val_recall_1: 0.8600
Epoch 9/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 142ms/step - accuracy: 0.9907 - loss: 0.0453 - precision_1: 0.9910 - recall_1: 0.9903 - val_accuracy: 0.8535 - val_loss: 0.5666 - val_precision_1: 0.8564 - val_recall_1: 0.8495
Epoch 10/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 142ms/step - accuracy: 0.9888 - loss: 0.0445 - precision_1: 0.9893 - recall_1: 0.9886 - val_accuracy: 0.8525 - val_loss: 0.5167 - val_precision_1: 0.8565 - val_recall_1: 0.8475
4. Model Evaluation:¶
Similar to Part 1, we will evaluate our model's performance using appropriate metrics.
predictions = model.predict(padded_sequences_valid)
valid_preds = np.argmax(predictions, axis=1)
metrics_val.run(valid_labels, valid_preds, "BiLSTM")
63/63 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step
metrics_val.plot()
Part 3: Word Embedding Add-Ons with Word2Vec¶
This part focuses on integrating pre-trained word embeddings into our model.
1. Loading Pre-trained Word2Vec Embeddings:¶
We'll load Word2Vec embeddings pre-trained on a large corpus.
from staticvectors import StaticVectors
word2vec_model = StaticVectors("neuml/word2vec")
2. Integrating Word2Vec into LSTM Model:¶
We'll use these embeddings as inputs to our LSTM model, potentially enhancing its ability to understand context and semantics.
import numpy as np
from tqdm import tqdm
import tensorflow as tf
# Initialize the embedding matrix
embedding_matrix = np.zeros((vocab_size, 300)) # 300 is the dimensionality of GoogleNews vectors
for word, i in tqdm(tokenizer.word_index.items()):
    if i >= vocab_size:
        continue
    try:
        vec = word2vec_model.embeddings([word])[0]  # take the first (and only) row
        if np.any(vec):  # skip zero/garbage fallback vectors
            embedding_matrix[i] = vec
    except (KeyError, ValueError):
        continue
# Define the model
model = Sequential([
    Input(shape=(max_length,)),
    Embedding(vocab_size, 300, embeddings_initializer=tf.keras.initializers.Constant(embedding_matrix), trainable=False),  # Set trainable to False
    Bidirectional(LSTM(64)),
    Dense(num_classes, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
100%|██████████| 25874/25874 [00:00<00:00, 340156.16it/s]
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ embedding_2 (Embedding)         │ (None, 128, 300)       │     1,500,000 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ bidirectional_2 (Bidirectional) │ (None, 128)            │       186,880 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 4)              │           516 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 1,687,396 (6.44 MB)
Trainable params: 187,396 (732.02 KB)
Non-trainable params: 1,500,000 (5.72 MB)
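The per-direction LSTM formula from Part 2 also accounts for the larger BiLSTM count here. As a worked check, with input_dim = 300 (the Word2Vec dimension) and units = 64:
$$ 4 \times (300 \times 64 + 64 \times 64 + 64) = 4 \times (19{,}200 + 4{,}096 + 64) = 4 \times 23{,}360 = 93{,}440 $$
$$ 2 \times 93{,}440 = 186{,}880 $$
The embedding matrix holds 5000 × 300 = 1,500,000 weights, all frozen by trainable=False, which leaves 186,880 + 516 = 187,396 trainable parameters.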
3. Training and Evaluating the Model:¶
We'll train our model with these new embeddings and evaluate to see if there's an improvement in performance.
from tensorflow.keras.callbacks import EarlyStopping
# Setup early stopping to stop training when validation loss stops improving
early_stopping = EarlyStopping(
    monitor='val_loss',  # Monitor validation loss
    patience=5,  # How many epochs to wait after min has been hit
    verbose=1,  # Verbosity level
    mode='min',  # Mode for the monitored quantity (minimizing loss)
    restore_best_weights=True  # Restore model weights from the epoch with the best value of the monitored quantity
)
history = model.fit(
    padded_sequences_train,
    train_labels_lstm,
    epochs=10,
    batch_size=128,
    validation_data=(padded_sequences_valid, valid_labels_lstm),
    callbacks=[early_stopping]
)
Epoch 1/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 10s 102ms/step - accuracy: 0.6725 - loss: 0.8980 - val_accuracy: 0.8235 - val_loss: 0.5224
Epoch 2/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 10s 132ms/step - accuracy: 0.8382 - loss: 0.4740 - val_accuracy: 0.8425 - val_loss: 0.4604
Epoch 3/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 137ms/step - accuracy: 0.8569 - loss: 0.4260 - val_accuracy: 0.8590 - val_loss: 0.4123
Epoch 4/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 137ms/step - accuracy: 0.8708 - loss: 0.3882 - val_accuracy: 0.8625 - val_loss: 0.3968
Epoch 5/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 139ms/step - accuracy: 0.8722 - loss: 0.3789 - val_accuracy: 0.8625 - val_loss: 0.3930
Epoch 6/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 138ms/step - accuracy: 0.8804 - loss: 0.3550 - val_accuracy: 0.8615 - val_loss: 0.3715
Epoch 7/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 139ms/step - accuracy: 0.8788 - loss: 0.3536 - val_accuracy: 0.8530 - val_loss: 0.4022
Epoch 8/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 137ms/step - accuracy: 0.8847 - loss: 0.3330 - val_accuracy: 0.8665 - val_loss: 0.3688
Epoch 9/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 137ms/step - accuracy: 0.8872 - loss: 0.3294 - val_accuracy: 0.8690 - val_loss: 0.3511
Epoch 10/10
79/79 ━━━━━━━━━━━━━━━━━━━━ 11s 137ms/step - accuracy: 0.8923 - loss: 0.3111 - val_accuracy: 0.8715 - val_loss: 0.3503
Restoring model weights from the end of the best epoch: 10.
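To see what `patience=5` means in practice, the following sketch replays the validation losses logged in Parts 2 and 3 through a simplified patience rule. It is an approximation of the idea behind `EarlyStopping`, not its implementation (options such as `min_delta` are ignored), and `early_stopping_epoch` is a hypothetical helper.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), both 1-indexed.

    stop_epoch is None when training runs through every epoch.
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch

# val_loss per epoch, copied from the two training logs in this notebook
bilstm_val_loss = [0.6697, 0.3931, 0.3497, 0.3649, 0.3822,
                   0.4049, 0.4647, 0.4758, 0.5666, 0.5167]
w2v_val_loss = [0.5224, 0.4604, 0.4123, 0.3968, 0.3930,
                0.3715, 0.4022, 0.3688, 0.3511, 0.3503]

print("BiLSTM:      ", early_stopping_epoch(bilstm_val_loss, patience=5))
print("BiLSTM + W2V:", early_stopping_epoch(w2v_val_loss, patience=5))
```

Under this rule the first BiLSTM would have stopped at epoch 8 and kept the epoch-3 weights, since its validation loss never improves after epoch 3. The Word2Vec model is still improving at epoch 10, so the callback never triggers and more epochs might help.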
4. Model Evaluation:¶
Similar to previous parts, we will evaluate our model's performance using appropriate metrics.
predictions = model.predict(padded_sequences_valid)
valid_preds = np.argmax(predictions, axis=1)
metrics_val.run(valid_labels, valid_preds, "BiLSTM + W2V")
metrics_val.plot()
63/63 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step
Part 4: Model Explainability (LIME / SHAP)¶
As machine learning models become more powerful, they also become more complex and opaque — especially deep learning models like LSTMs or transformer-based architectures. This complexity makes it harder to understand why a model makes a specific prediction.
In this part, we will explore model explainability, a crucial step in building trustworthy and transparent machine learning systems.
4.1 Why Explainability Matters¶
Imagine you’ve built a model that predicts whether a customer is likely to churn, or whether a loan should be approved. Even if your model is 90% accurate, you might be asked:
🤔 "But why did the model make that decision?"
This is where explainability becomes essential.
Why it's important:¶
- Trust: Users are more likely to trust predictions they can understand.
- Debugging: Helps identify spurious correlations or biases in the model.
- Fairness & Ethics: Ensures decisions are not based on sensitive or discriminatory attributes.
- Regulatory Compliance: In some domains (like finance or healthcare), explainability is required by law.
Two Main Categories of Explainability:¶
- Global Explanations: Understanding the overall model behavior
  Example: Which words generally influence sentiment predictions the most?
- Local Explanations: Understanding individual predictions
  Example: Why was *this* review classified as negative?
We'll focus primarily on local explanations using:
- LIME (for simpler models like TF-IDF + Logistic Regression)
Let’s begin by applying LIME to our TF-IDF baseline model.
4.2 🧪 Applying LIME to the TF-IDF Model¶
Now that we understand why explainability matters, let's start by applying LIME (Local Interpretable Model-agnostic Explanations) to our first model:
➡️ A TF-IDF + Logistic Regression pipeline.
LIME works by slightly perturbing the input text and seeing how the model prediction changes.
From this, it builds a local, interpretable surrogate model (like a linear regression) to approximate the complex model's behavior near that input.
We'll explain:
- A single prediction for a text sample
- Which words had the most impact (positive or negative) on the predicted label
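The perturb-and-fit idea behind LIME can be sketched in a few lines of NumPy. This is a deliberately simplified illustration, not the `lime` library's algorithm: words are dropped at random, samples are weighted by a hand-picked exponential kernel, and a plain weighted least-squares line is fitted. `lime_sketch` and `toy_predict` are hypothetical names, and the toy black box only looks for the word "goal".

```python
import numpy as np

def lime_sketch(text, predict_fn, n_samples=500, kernel_width=0.75, seed=0):
    """Toy local surrogate: returns (word, weight) pairs sorted by |weight|."""
    words = text.split()
    rng = np.random.default_rng(seed)
    # 1 = keep the word, 0 = drop it; the first row is the unperturbed text
    masks = rng.integers(0, 2, size=(n_samples, len(words)))
    masks[0] = 1
    perturbed = [" ".join(w for w, keep in zip(words, m) if keep) for m in masks]
    y = np.asarray(predict_fn(perturbed), dtype=float)
    # Samples closer to the original text (fewer words removed) count more
    distance = 1.0 - masks.mean(axis=1)
    sample_weight = np.exp(-(distance ** 2) / kernel_width ** 2)
    # Weighted least squares with an intercept column
    X = np.hstack([np.ones((n_samples, 1)), masks])
    sw = np.sqrt(sample_weight)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return sorted(zip(words, coef[1:]), key=lambda p: abs(p[1]), reverse=True)

def toy_predict(texts):
    """Made-up black box: probability of 'Sports' depends only on the word 'goal'."""
    return [0.9 if "goal" in t.split() else 0.1 for t in texts]

explanation = lime_sketch("late goal wins the match", toy_predict)
for word, weight in explanation:
    print(f"{word}: {weight:+.4f}")
```

The surrogate assigns essentially all of the weight to "goal", because removing that word is the only perturbation that changes the black box's output. The real library adds refinements such as feature selection and a different distance measure.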
grid_search.best_estimator_.steps[1][1].coef_
array([[-0.06284034, -0.03205257, -0.02292822, ..., -0.03027038,
-0.05675394, -0.05984418],
[-0.02164632, -0.00278093, -0.02801081, ..., 0.0435467 ,
0.17139464, 0.13993195],
[-0.08727509, -0.04641301, -0.06808652, ..., 0.02400592,
-0.055042 , -0.03742413],
[ 0.17176176, 0.0812465 , 0.11902554, ..., -0.03728224,
-0.0595987 , -0.04266364]])
from lime.lime_text import LimeTextExplainer
# LIME explainer function for the TF-IDF pipeline
def explain_tfidf_prediction(text_instance, pipeline, class_names):
    # Create a LIME text explainer
    explainer = LimeTextExplainer(class_names=class_names)
    # Get explanation for the prediction
    exp = explainer.explain_instance(
        text_instance,
        pipeline.predict_proba,
        num_features=10,
        top_labels=len(class_names)  # Explain all classes
    )
    # Display basic information
    print(f"Text: {text_instance}")
    pred_class = pipeline.predict([text_instance])[0]
    print(f"Predicted class: {class_names[pred_class]}")
    # Get probabilities for all classes
    probs = pipeline.predict_proba([text_instance])[0]
    print("\nClass probabilities:")
    for i, class_name in enumerate(class_names):
        print(f"{class_name}: {probs[i]:.4f}")
    # Create visualization for each class
    plt.figure(figsize=(20, 15))
    # Get the labels that LIME actually explained
    top_labels = exp.available_labels()
    for i, label_id in enumerate(top_labels):
        plt.subplot(2, 2, i+1)
        # Get the explanation for this class
        exp_list = exp.as_list(label=label_id)
        # Extract words and weights
        words = [x[0] for x in exp_list]
        weights = [x[1] for x in exp_list]
        # Sort for better visualization
        pairs = sorted(zip(words, weights), key=lambda x: x[1])
        words = [x[0] for x in pairs]
        weights = [x[1] for x in pairs]
        # Create bar chart
        colors = ['red' if w < 0 else 'green' for w in weights]
        y_pos = np.arange(len(words))
        plt.barh(y_pos, weights, color=colors)
        plt.yticks(y_pos, words)
        plt.title(f"Explanation for class: {class_names[label_id]}")
        plt.axvline(x=0, color='black', linestyle='-', alpha=0.5)
    plt.tight_layout()
    plt.show()
    # Print top contributing words for each class
    for label_id in top_labels:
        print(f"\nTop features for class: {class_names[label_id]}")
        exp_list = exp.as_list(label=label_id)
        for word, weight in exp_list:
            print(f"{word}: {weight:.4f}")
    return exp
# Example text from test set
example_text = test_data[0]
class_names = ['World', 'Sports', 'Business', 'Sci/Tech']
lime_exp_tfidf = explain_tfidf_prediction(example_text, grid_search.best_estimator_, class_names)
Text: Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.
Predicted class: Business

Class probabilities:
World: 0.2938
Sports: 0.0401
Business: 0.5727
Sci/Tech: 0.0933

Top features for class: Business
talks: -0.1690
firm: 0.1081
Federal: 0.1026
pension: 0.0954
parent: 0.0504
Unions: 0.0389
workers: 0.0337
disappointed: -0.0333
say: -0.0236
Mogul: 0.0218

Top features for class: World
talks: 0.2220
Federal: -0.0861
firm: -0.0778
pension: -0.0648
parent: -0.0341
say: 0.0317
Fears: 0.0274
Mogul: -0.0207
representing: -0.0183
Unions: -0.0171

Top features for class: Sci/Tech
talks: -0.0384
pension: -0.0213
Fears: -0.0165
Unions: -0.0162
workers: -0.0162
parent: -0.0144
Federal: 0.0117
firm: -0.0103
say: 0.0102
disappointed: 0.0067

Top features for class: Sports
workers: -0.0287
Federal: -0.0261
firm: -0.0196
say: -0.0162
talks: -0.0127
disappointed: 0.0108
pension: -0.0079
representing: 0.0069
Mogul: 0.0059
Unions: -0.0046
🧠 Interpretation of LIME Explanation¶
Let's break down what LIME revealed about the model's reasoning for this particular prediction.
✅ Predicted Class: Business¶
LIME shows us the top 10 words that contributed positively or negatively to each possible class (Business, World, Sci/Tech, Sports).
💬 For Class: Business¶
- Words like "Federal", "firm", and "pension" have positive weights, meaning they support the Business prediction.
- The word "talks" actually detracts from the Business prediction (negative weight), suggesting it's a bit ambiguous.
🌍 For Class: World¶
- Interestingly, "talks" strongly supports the World class here.
- Other Business-related terms (e.g., "Federal", "firm") detract from a World prediction.
💡 Insights:¶
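- Words like "workers", "Unions", and "say" appear across multiple classes with small influence, showing they’re more generic.
- "talks" cuts both ways: it is the strongest push toward World (+0.2220) and the strongest push away from Business (-0.1690). Since the model is a bag of n-grams, the word is not reinterpreted in context; LIME simply shows how a single feature shifts probability between competing classes.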
🧭 Takeaway: LIME helps us peek inside the model's black box and see which features are driving predictions. It also shows that certain words may support multiple classes, but with different intensities.
Ready to move on? Let’s now explore:
4.3 📊 Comparing Explanations for LSTM and Word2Vec Models¶
➡️ Here, we'll try to interpret more complex models (like LSTM and Word2Vec-based models) using LIME and compare how their reasoning differs from the simpler TF-IDF model.
def prepare_text_for_lstm(text, tokenizer, max_length):
    """Prepare a single text input for the LSTM model"""
    from tensorflow.keras.preprocessing.sequence import pad_sequences
    sequences = tokenizer.texts_to_sequences([text])
    # Use the same padding/truncation as at training time ('post');
    # the pad_sequences default is 'pre'
    padded_seq = pad_sequences(sequences, maxlen=max_length,
                               padding='post', truncating='post')
    return padded_seq

def lstm_predict_proba(texts):
    """Prediction function for LIME to use with LSTM model"""
    # Predict all perturbed texts in one batch instead of one call per text
    sequences = tokenizer.texts_to_sequences(list(texts))
    padded = pad_sequences(sequences, maxlen=max_length,
                           padding='post', truncating='post')
    return model.predict(padded, verbose=0)
def explain_lstm_prediction(text_instance, class_names):
    # Create a LIME text explainer
    explainer = LimeTextExplainer(class_names=class_names)
    # Get explanation for the prediction
    exp = explainer.explain_instance(
        text_instance,
        lstm_predict_proba,
        num_features=10,
        top_labels=len(class_names)  # Explain all classes
    )
    # Display basic information
    print(f"Text: {text_instance}")
    padded = prepare_text_for_lstm(text_instance, tokenizer, max_length)
    prediction = model.predict(padded, verbose=0)
    predicted_class = np.argmax(prediction[0])
    print(f"Predicted class: {class_names[predicted_class]}")
    # Get probabilities for all classes
    probs = prediction[0]
    print("\nClass probabilities:")
    for i, class_name in enumerate(class_names):
        print(f"{class_name}: {probs[i]:.4f}")
    # Create visualization for each class
    plt.figure(figsize=(20, 15))
    # Get the labels that LIME actually explained
    top_labels = exp.available_labels()
    for i, label_id in enumerate(top_labels):
        plt.subplot(2, 2, i+1)
        # Get the explanation for this class
        exp_list = exp.as_list(label=label_id)
        # Extract words and weights
        words = [x[0] for x in exp_list]
        weights = [x[1] for x in exp_list]
        # Sort for better visualization
        pairs = sorted(zip(words, weights), key=lambda x: x[1])
        words = [x[0] for x in pairs]
        weights = [x[1] for x in pairs]
        # Create bar chart
        colors = ['red' if w < 0 else 'green' for w in weights]
        y_pos = np.arange(len(words))
        plt.barh(y_pos, weights, color=colors)
        plt.yticks(y_pos, words)
        plt.title(f"Explanation for class: {class_names[label_id]}")
        plt.axvline(x=0, color='black', linestyle='-', alpha=0.5)
    plt.tight_layout()
    plt.show()
    # Print top contributing words for each class
    for label_id in top_labels:
        print(f"\nTop features for class: {class_names[label_id]}")
        exp_list = exp.as_list(label=label_id)
        for word, weight in exp_list:
            print(f"{word}: {weight:.4f}")
    return exp
example_text = test_data[0]
model_exp_lime = explain_lstm_prediction(example_text, class_names)
Text: Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.
Predicted class: Sci/Tech

Class probabilities:
World: 0.2572
Sports: 0.2359
Business: 0.1706
Sci/Tech: 0.3363

Top features for class: Sci/Tech
firm: -0.0460
they: 0.0259
parent: -0.0234
talks: -0.0197
at: 0.0181
with: 0.0179
are: 0.0175
after: 0.0146
say: 0.0122
for: 0.0077

Top features for class: World
firm: 0.0168
talks: 0.0167
they: -0.0119
after: -0.0117
say: -0.0112
parent: 0.0102
at: -0.0083
are: -0.0061
with: -0.0060
for: -0.0049

Top features for class: Sports
Federal: 0.0069
talks: -0.0058
say: 0.0044
T: 0.0030
after: 0.0026
parent: 0.0025
with: 0.0024
N: 0.0016
workers: -0.0013
Unions: -0.0011

Top features for class: Business
firm: 0.0283
with: -0.0157
they: -0.0145
are: -0.0123
at: -0.0114
parent: 0.0096
talks: 0.0074
say: -0.0065
after: -0.0065
T: 0.0040
🔍 LIME Explanation for LSTM Model¶
🧠 Sentence Analyzed:¶
"Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."
✅ Predicted Class: Sci/Tech¶
The LSTM picked Sci/Tech, but with only 33.6% probability — barely above the other classes. This is essentially a low-confidence, near-tie prediction.
📊 Class Probabilities:¶
- Sci/Tech: 0.3363
- World: 0.2572
- Sports: 0.2359
- Business: 0.1706
The narrow gap between the top three classes already tells us the model is uncertain about this example.
💡 What LIME Reveals¶
Top Features for Sci/Tech (the predicted class):¶
- Function words like "they", "at", "with", "are", and "after" push the prediction toward Sci/Tech.
- Content words that intuitively belong to Business or World — "firm", "parent", "talks" — actually push against Sci/Tech.
- In other words: the model lands on Sci/Tech almost by default, not because of any topical evidence.
Across Other Classes:¶
- Most of the same function words ("they", "after", "with") appear as negative contributors to World and Business classes.
- Words like "firm" slightly boost the Business class but are downplayed for Sci/Tech.
- In the Sports class, the weights are several times smaller than for the other classes (the largest is 0.0069). No single word moves the Sports probability much, even though Sports still receives 23.6% of the probability mass.
🧠 Interpretation:¶
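Unlike the TF-IDF model, which focused on specific content words, the LSTM model seems to be relying heavily on function words here. Keep in mind that the TF-IDF pipeline removes English stop words, so it could not have used them in the first place. Possible reasons for the LSTM's behavior:
- The sequential nature of LSTM, which captures contextual dependencies
- A possible lack of domain-specific keywords driving this prediction: words outside the 5,000-word tokenizer vocabulary (likely rare names such as "Newall" or "Mogul") are dropped before they reach the model
- Preprocessing at explanation time: the inputs must be padded the same way as in training (padding='post', truncating='post'). pad_sequences defaults to 'pre', and a mismatch can by itself produce near-uniform, hard-to-interpret predictions like this one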