Tiny Shakespeare RNN (Inspired by Karpathy)¶
In this notebook:
- Load the karpathy/tiny_shakespeare dataset.
- Build a simple character-level RNN (PyTorch) to predict next characters.
- Train for several thousand steps, showing generated text every 50 steps.
- Display next-token probability distributions.
- Display an activation map for the RNN (hidden states) to see what’s lighting up.
For reference, see Andrej Karpathy’s RNN effectiveness blog post.
# Install dependencies if needed:
# !pip install datasets torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = 42
torch.manual_seed(seed)
<torch._C.Generator at 0x10bce6950>
1. Load the Tiny Shakespeare Data¶
We get a single text string from the dataset, then build a char-level vocabulary.
tiny_data = load_dataset('karpathy/tiny_shakespeare')
train_texts = tiny_data['train']['text'] # list of strings
# Typically there's only 1 record with entire text.
all_text = " ".join(train_texts)
print(f"Dataset length in chars: {len(all_text)}")
# Build vocab
chars = sorted(list(set(all_text)))
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size}")
# char to int, int to char
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for ch,i in stoi.items() }
Dataset length in chars: 1003854
Vocabulary size: 65
We’ll define functions to encode sequences of characters to integer IDs and decode them back to characters. Then we’ll create train/validation splits. For simplicity, we keep the entire text as a single sequence, though more sophisticated approaches might chunk it.
def encode_text(text):
return [stoi[ch] for ch in text]
def decode_ids(ids):
return ''.join(itos[i] for i in ids)
data_ids = torch.tensor(encode_text(all_text), dtype=torch.long)
print("Encoded data shape:", data_ids.shape)
# Let's do 90% for train, 10% for val
n = int(0.9*len(data_ids))
train_ids = data_ids[:n]
val_ids = data_ids[n:]
print("train, val shapes:", train_ids.shape, val_ids.shape)
Encoded data shape: torch.Size([1003854])
train, val shapes: torch.Size([903468]) torch.Size([100386])
2. Create a Char-Level RNN Model¶
We’ll use a small 1-layer LSTM that outputs logits over vocab_size at each step. (You could also hand-roll a naive RNN closer to Karpathy’s minimal example, as sketched below, but PyTorch’s built-in LSTM makes it easy to expose the hidden states and so on.)
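For reference, a minimal hand-rolled RNN cell might look like the following sketch. It uses the classic tanh recurrence from Karpathy's min-char-rnn; the class name NaiveCharRNN is just for illustration, and it is not used anywhere else in this notebook — the LSTM class below is what we actually train.
class NaiveCharRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.W_xh = nn.Linear(vocab_size, hidden_size, bias=False)  # input -> hidden
        self.W_hh = nn.Linear(hidden_size, hidden_size)             # hidden -> hidden
        self.W_hy = nn.Linear(hidden_size, vocab_size)              # hidden -> logits

    def forward(self, x, hidden=None):
        # x: (batch, seq) of char IDs; one-hot instead of a learned embedding
        x_onehot = F.one_hot(x, num_classes=self.W_xh.in_features).float()
        B, S, _ = x_onehot.shape
        h = torch.zeros(B, self.hidden_size, device=x.device) if hidden is None else hidden
        logits = []
        for t in range(S):
            # h_t = tanh(W_xh x_t + W_hh h_{t-1})
            h = torch.tanh(self.W_xh(x_onehot[:, t]) + self.W_hh(h))
            logits.append(self.W_hy(h))
        return torch.stack(logits, dim=1), h  # (batch, seq, vocab_size), (batch, hidden_size)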
class CharRNN(nn.Module):
def __init__(self, vocab_size, embed_dim=64, hidden_size=128):
super().__init__()
self.embed = nn.Embedding(vocab_size, embed_dim)
self.rnn = nn.LSTM(embed_dim, hidden_size, num_layers=1, batch_first=True)
self.fc = nn.Linear(hidden_size, vocab_size)
def forward(self, x, hidden=None):
# x: (batch, seq)
emb = self.embed(x) # (batch, seq, embed_dim)
if hidden is None:
out, hidden = self.rnn(emb)
else:
out, hidden = self.rnn(emb, hidden)
logits = self.fc(out) # (batch, seq, vocab_size)
return logits, hidden
def generate(self, start_char, max_new_tokens=100):
# quick sampling method
self.eval()
with torch.no_grad():
hidden = None
x = torch.tensor([[stoi[start_char]]], dtype=torch.long, device=device)
out_str = [start_char]
for _ in range(max_new_tokens):
logits, hidden = self.forward(x, hidden=hidden)
# take the last time step
last_logits = logits[:,-1,:] # shape (1, vocab_size)
probs = F.softmax(last_logits, dim=-1)
# sample
ix = torch.multinomial(probs, num_samples=1)
x = ix
ch = itos[ix.item()]
out_str.append(ch)
return ''.join(out_str)
def next_token_probs(self, context):
# Returns the next-token prob distribution for a given context string.
self.eval()
with torch.no_grad():
hidden = None
x = torch.tensor([encode_text(context)], dtype=torch.long, device=device)
logits, hidden = self.forward(x)
last_logits = logits[0,-1,:] # shape (vocab_size,)
probs = F.softmax(last_logits, dim=0)
return probs.cpu().numpy()
def get_activations(self, context):
# We'll fetch hidden state after feeding context
self.eval()
with torch.no_grad():
x = torch.tensor([encode_text(context)], dtype=torch.long, device=device)
emb = self.embed(x)
out, hidden = self.rnn(emb)
# out shape: (1, seq_len, hidden_size)
# hidden: tuple((1, batch, hidden_size), (1, batch, hidden_size)) for LSTM
return out.squeeze(0).cpu().numpy() # shape: (seq_len, hidden_size)
2.1 Prepare a batch loader function¶
We’ll define a simple function that samples random positions in the encoded data and returns (X, Y), where X is a chunk of input characters and Y is the same chunk shifted one position forward (the next characters).
def get_batch(data_ids, seq_len=128, batch_size=32):
# Random starting indices
ix = np.random.randint(0, len(data_ids) - seq_len - 1, (batch_size,))
Xb = []
Yb = []
for i in ix:
chunk = data_ids[i:i+seq_len]
target = data_ids[i+1:i+seq_len+1]
Xb.append(chunk)
Yb.append(target)
# Use torch.stack to combine the tensors
Xb = torch.stack(Xb)
Yb = torch.stack(Yb)
return Xb, Yb
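A quick sanity check (using the tensors defined above) that the targets really are the inputs shifted by one character:
# X and Y should both be (batch_size, seq_len), with Y shifted one char to the left.
Xb, Yb = get_batch(train_ids, seq_len=128, batch_size=4)
print(Xb.shape, Yb.shape)
print(repr(decode_ids(Xb[0][:40].tolist())))
print(repr(decode_ids(Yb[0][:40].tolist())))  # same text, offset by one character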
3. Training the RNN¶
We’ll train for 3000 steps (the loop below calls them epochs, though each one processes a single random batch), sampling text every 50 steps to show progress. Karpathy’s post shows that small RNNs can produce surprisingly coherent text after relatively little training.
Later sections add a next-token probability map and an activation map for a sample context.
model = CharRNN(vocab_size=vocab_size, embed_dim=64, hidden_size=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
def compute_loss_fn(logits, targets):
B,S,V = logits.shape
logits_flat = logits.view(B*S, V)
targets_flat = targets.view(B*S)
return F.cross_entropy(logits_flat, targets_flat)
max_epochs = 3000
seq_len = 128
batch_size = 32
losses = []
for epoch in range(0, max_epochs+1):
model.train()
Xb, Yb = get_batch(train_ids, seq_len=seq_len, batch_size=batch_size)
Xb, Yb = Xb.to(device), Yb.to(device)
optimizer.zero_grad()
logits, _ = model(Xb)
loss = compute_loss_fn(logits, Yb)
loss.backward()
optimizer.step()
losses.append(loss.item())
    if epoch % 50 == 0:
        print(f"Epoch {epoch}, loss: {loss.item():.4f}")
        # sample text every 50 epochs to show progress
        model.eval()
        sample = model.generate(start_char='T', max_new_tokens=200)
        print(f"\n=== Sample after epoch {epoch} ===\n{sample}\n===========================\n")
Epoch 0, loss: 4.1908 === Sample after epoch 0 === T!'iEU;kTgUglOvNvkinqP;gRRCzryBWc;UcrRtoYz-JcZ;nbL?uZc TrqoS-Kr sS!v,xszIW?F!Cmz,Zp-& pHugDZok.z-Yf$cbzL$rDaB! ApEgY-S lOQ wHYc?!JTR$pXRT$:qt..e:lFqOFw EHlpXNAHARJUVTmSgDN ',XJd,M!hYyrXRM;oJIlv. .byhy =========================== Epoch 50, loss: 2.4465 === Sample after epoch 50 === T:hen touvlh, Wonorith I fouis kouso nnvo he'n son shofrshen th met aret altour y yor axLl, of fig orerl ee too itisisg waplshathe sove, pnoc, be tof homr itn ;e, toe. A h. EReN thedic fithe for imaus =========================== Epoch 100, loss: 2.2058 === Sample after epoch 100 === Tould be crecke sont stell tuntat breccand nofthak thy book. Foun pont mudio fonce shet sy loost ss of thish vest thee, to-wed yur,; and, forith as seen cumne I it all owull dethe' file.er kev Nome. Th =========================== Epoch 150, loss: 2.0445 === Sample after epoch 150 === Th ring thiph the shell chat kinent, and cakfingh? Pinchabe the have. Shy linve themsand, widing that thin I ham hes kinds OUH Nf rues attion a trirveat theum gekie beef, beat on I beav. LULENMV: Hom =========================== Epoch 200, loss: 2.0038 === Sample after epoch 200 === T! havering deads of collouss rich. COFYILIUS: This to spouce: Rothe suth my parkarys Cove's thun siap chadow'e, with hay greeng? And sten's ance will ic ulefte nocdoutule to blow, had puntorr of ou y =========================== Epoch 250, loss: 1.9360 === Sample after epoch 250 === T-KI Lizend of my now rest a dilk you lingon: To body? WARICHO: We lord: O, father wist, Go ture a pead him comy that wither in these contrurs's! My prise Jdie. Vitine, Or fart at spasicared, Noned hi =========================== Epoch 300, loss: 1.8067 === Sample after epoch 300 === T: If menter a habketh, Catciang, How all thee case There our daghter So what, braind with dondake in stakn. Seiwele, Wo eavem: and JKwarter hom all shall have it twe, tatle with hast. RUCIY: In the =========================== Epoch 350, loss: 1.8255 === Sample after epoch 350 === TEO: By to What apperopsce. LARIO: Vormand then. MENENTIUS: Toath proncomernid, in thou supon I iscgoot word homore him to winsed fonoum is'row--purt there celones you relioker-shalf your low? A till =========================== Epoch 400, loss: 1.7309 === Sample after epoch 400 === Theat, Mull by your me, newed what I havring bose, Mone that sh by Sill. RICHY: Whil deetby of he fuperins is the liencher chad; tite, heard. shim say, lowd? But we ight King and by more 'notle shall =========================== Epoch 450, loss: 1.6591 === Sample after epoch 450 === TV Rombrows knows be be astered with werrand awfores barswaid. Seind I'll there: to though it the vein let. LADY CARWICLIND: Lectife hill for, stis, penwick, And gear? Thy inaturent, we will lears oy =========================== Epoch 500, loss: 1.6838 === Sample after epoch 500 === Thre chanted langbatter? ake hearbly in islence lawing a oul with despand that, sit! CFlinter streacl thee vering of word alouncpedonce, thou then their daues Richt a knong I hord and but come alms the =========================== Epoch 550, loss: 1.6996 === Sample after epoch 550 === TI' she conmend and That there was, is hound tounce her hive of thus veiases, Ate sim I'll beining all not ver weeven, is eary o' tel, lite brage thy, Prainw. KING HENRTY OV: Go, you gold placinity an =========================== Epoch 600, loss: 1.6507 === Sample after epoch 600 === TGBURLIBAN: List the harten bantul, not buttent of he: I dome, aSmind not plasor? 
LON: O KING EDWARD III:: Thou, my frandighs in heow upon they? VOR: Meace was yet is nease that your ourpers! O mest =========================== Epoch 650, loss: 1.5590 === Sample after epoch 650 === The earcued'd reme say I foll? 'on she know of that countin the party: Crorrain: the onoring figfer's bliing yurses, set theo wons dots the paich.' QUEEN ELIZA: Nowe, sir; she all thy soons the waupy =========================== Epoch 700, loss: 1.6363 === Sample after epoch 700 === Tust and belunfiers up; When it us is thee, father at the? As I hath distabes Hed or would word call, had it man beard-ghel? DUKE OF YORK: Frird: Not Grave; mown and yough youm abuth it spasted there =========================== Epoch 750, loss: 1.6209 === Sample after epoch 750 === TH: Neep. SOMINBET: Come grace! ERCIT: Let hand tentroun where so as this' the draw, It father is menty, thou melin's down, Be uss! o' gall. Schap, we seetland he priece? has that so wops, all peast, =========================== Epoch 800, loss: 1.6257 === Sample after epoch 800 === Ty. BENVOLIO: Ay, she Evile appeaming mine come: But twe lead, sir, loke thee Geem this with shoull Morrart and To may your grace My shold company murtly fold be perseft as swo; Why, we beys this boy. =========================== Epoch 850, loss: 1.6097 === Sample after epoch 850 === TERONTE: What speak a repole-thim contale, and comes, by the perpot poieds: be go you spewamest have meet sich your daster, I did, nack, roy,-- Musting, but where thring him! I cannaly love not I proil =========================== Epoch 900, loss: 1.5910 === Sample after epoch 900 === TPUF Lard: S didelf To us hea, s know, we men, lofe, Gevent the speed it the our serveins its, some mus, My gring poor restaging held he may mighters requress to us. Whenoness the know An as from the k =========================== Epoch 950, loss: 1.5435 === Sample after epoch 950 === TERS O, for prouns too smalt'd my fait, Why, wilt he handward Rome is a griant--it-blood, or hours the king. First Servingmer: Scause hill? ELWARD: No, my vowardles. GXELTER: And to thy trair. Which =========================== Epoch 1000, loss: 1.5303 === Sample after epoch 1000 === That, The peace, beyoud done of play my bryose abours bewill not shapple here; as I may me ruming, That saon, we are us seet thou cannot toldie-tint; and shall been of my foother an incont. We'eves; He =========================== Epoch 1050, loss: 1.5319 === Sample after epoch 1050 === Te prity. OSFORK: No; by roued. KING RICHARD III: Would have wetching after night. Thon the bide. The spoth or my his: No follow fall. MENENIUS: Way away, of the comity Your deatn: Rnet excellfmess =========================== Epoch 1100, loss: 1.6140 === Sample after epoch 1100 === TNuke lady not exfude, puke see, that cheere. This mistare to liek have use Edward wasten; For a noblers. But is timeon'd ments up not; For coundeal: A pray? KING HENRY VI: Hell you are That it; so be =========================== Epoch 1150, loss: 1.5286 === Sample after epoch 1150 === TINGBROKE: Not may the more! So Mantumbatious gentlem! way, Thy wisse 'If Comets than my mongers. PISARD MANNGRY: If even's suldy last Becear off, I waul be'd thy hand; That when free, beat if you are =========================== Epoch 1200, loss: 1.5888 === Sample after epoch 1200 === That way the fair call you speak come to speak? Nor ones The world: Is you have it brist To comfout, Cenuly pitty of Hastinger's faticeer, and you care, how after I, hast and competty to me your lord. 
=========================== Epoch 1250, loss: 1.5573 === Sample after epoch 1250 === Tcuntuls. I do be by must, feart To me. Whom it with this: We how as your hards, thy duke will be can leases, How, by Kenrel in heir is nofly some hand, you will her is a pards; of his but 'at excust-c =========================== Epoch 1300, loss: 1.5149 === Sample after epoch 1300 === Thus child made. CORIOLANUS: Then you cannot when so we heelf, every are pursed in my levenge't and sad, there streath Mine, it think am assident the victoraty: if entcam, they some, peace. BENVOLIO: =========================== Epoch 1350, loss: 1.5509 === Sample after epoch 1350 === Tuling at up, or Rickle as it with me; Thou, what good many ingain't being pleasure, no counser; the pasting my youstain'd cormost his plain, shis, there is gentrement my follows on eyes to the mottore =========================== Epoch 1400, loss: 1.4928 === Sample after epoch 1400 === To strongs all, Whiise true, thou wiltlcoman and bidds, hearty, and therefore hate, leave be thy resome against cousin our brother, I pot our talous toone prease: and it is do? HERMERBESARD: What iste =========================== Epoch 1450, loss: 1.5000 === Sample after epoch 1450 === Tar naturn that you let a death for slingure and for makes is gentlement's then common, Come? First Murderer: Vincel to the gand King and appand, resure when chands, as God all Pully, bool, and thou t =========================== Epoch 1500, loss: 1.4978 === Sample after epoch 1500 === T: And look but thuson tooble thries, the but my ravoldden in twaths! And An? this is ciling, like pludines The world. Takep me wrath to the him murse. POMPEY: Come, my send of Englentle Now then I am =========================== Epoch 1550, loss: 1.5436 === Sample after epoch 1550 === This moraft, that they for Mead And thou well. O, but thou shave As once not such among thee run, thereance thy will to fair mentlememiet; Should he greath; I will not's his latester slain with honour' =========================== Epoch 1600, loss: 1.5095 === Sample after epoch 1600 === Thol my smon breford IVed play is thy hoee, answer the heard! Provord'd truit rewar: Nor amonty; Who small fill, I tanny it O'erled! My Lord what this maken and to the plastorys of our know breats and =========================== Epoch 1650, loss: 1.5353 === Sample after epoch 1650 === Tice and child; Appey! 'Tswixted my bearned. How to let make me good like me not me to more own asiness. CORIOLANUS: What we new, they now whoy he say! hose beginal was naith your singer to me; Besend =========================== Epoch 1700, loss: 1.4564 === Sample after epoch 1700 === To seep them for years; Who has me not? NoLe I lady: The shen'ly deprive Ournot letted, Endear thee is mother Nor I Give the are-proper hath other, and with sudenity you was it. Your pride her away; h =========================== Epoch 1750, loss: 1.4577 === Sample after epoch 1750 === TERS: Gords have hast, though, that ides with with thy speak to a yous arce, To mible fellessgal bests, if by-condention. 
PRINCE EDWARD: No, distry majesty, Diding of the Leard Poriuble Duking my fine =========================== Epoch 1800, loss: 1.4972 === Sample after epoch 1800 === To less and biles with hast of my mocess-- In a kind made I find you, With destreath Plate Shall be least my brother beat should condedue heats away To knows it word, and him shame how must not, and wh =========================== Epoch 1850, loss: 1.4385 === Sample after epoch 1850 === TANGBY: And are out to thee, paint again from the thing in fastary; the say. We comes my moie; His ear? AUTOLYCUS: Sorrow not be lingel, wread on Bo. Nussend: No. Hence all was found to tell that woo =========================== Epoch 1900, loss: 1.4748 === Sample after epoch 1900 === Thied have not become all the father it maipent to put earth with allise put mine lady's hurty thou banos never busines quitterce should done his elin. But of is thou will sir, Baporning not thousesses =========================== Epoch 1950, loss: 1.4893 === Sample after epoch 1950 === THBy Lord you to denutencanch a pinco Whem, my bodg! Good to there in bround covervein; He say the gods. If, Tupest strazed do have you dots I frof marry when ever with thyself! Will hell my hands me d =========================== Epoch 2000, loss: 1.4765 === Sample after epoch 2000 === Thy shall may Roman! AUTOLYCUS: Let empland? There Intiling tains not bloods: In hear no changed, and that I am lady's pleasure from top to hear he our sperciat crusners, vance heavencus, Nor for the =========================== Epoch 2050, loss: 1.4833 === Sample after epoch 2050 === Thoist, and lay tho, I'll be on ermid blood, stay-famed to me, The basweef sad work'd with on, tet is to pervany, And sweet them, By you, so we have shall you banirity-hat. Actoer: Why, why make you f =========================== Epoch 2100, loss: 1.4370 === Sample after epoch 2100 === T! Scarth o' now, Put frie's fortummeth of my high and on homes to hidlemnetters is fill, And, is am bening in dageth and our as young and end him. LUCIUS: He carlies I was done to me, For bring chang =========================== Epoch 2150, loss: 1.4534 === Sample after epoch 2150 === TI: Mightress'd to your man read, She seem I will make alasing willy myself. WARWICK: O, since that sir, if this very man comes herefork, He shalt with when thou art but which as sweet beaut, where co =========================== Epoch 2200, loss: 1.4652 === Sample after epoch 2200 === Thie sir, Or in etermy for since to char by secuty to seven other liquit Of God!' Amend, but more tage To it sorrow with born. Away of to England. ISABELLA: 'That pray upon the king. LEONTES: Not of =========================== Epoch 2250, loss: 1.4394 === Sample after epoch 2250 === The need fear; Weht Talus I might doth reason's falling. Fir, shafe, I have forged, in my enemy of Bothand him sin and crotherwer too. THOMAS MIWBR: Now'd it been to two Sicies: upon, I be speak are m =========================== Epoch 2300, loss: 1.4567 === Sample after epoch 2300 === This true, There I can slafe wits re; Of I? Yet-I like in the noble mimer of thy words more to wolkeng forger; And that rootes is end were rue my gracious worture of are thing thereaffer him and hart h =========================== Epoch 2350, loss: 1.4665 === Sample after epoch 2350 === TINGBROK: I do beaught whath defend. LUCIO: How conlie and turn'd, And hath the neyard more! Ireporle no? 
CAMILLO: I did without the rammeggal princely action: When he have has not ascience; but they =========================== Epoch 2400, loss: 1.4228 === Sample after epoch 2400 === THANRY OF LAD: Up, though not be slew the foot thou arm to kingdous death of all; If Warwick what is not this mown I bements of Leinst, then were all these. POLIXENES: O, good deadness. Third Antir o =========================== Epoch 2450, loss: 1.3804 === Sample after epoch 2450 === TINGBR OF AULEL: If I have an every, And read much standed be far helbed upon off like Chirdly gracious; thou should to like when devise your looker; and let the Lord lommment Bencond it. Clown: No mo =========================== Epoch 2500, loss: 1.4343 === Sample after epoch 2500 === THERS HARBEON: What is masters for sworn, and all the heart with a deet, I no loves and I'll put a divine after That sin. I have things and strike and not me, strong of vitious one giff. KING RICHARD =========================== Epoch 2550, loss: 1.4236 === Sample after epoch 2550 === To do I say, yet, my lord, He should, And such you me out made it, holderly, Samibums, Will not new than hath it slaught as the peace. VOLUMNIA: Traivility. GREGORY: more, not presence our friend Hee =========================== Epoch 2600, loss: 1.4352 === Sample after epoch 2600 === This should yet offen cry me? What I thou westy beg thy princely and renothers, for success: Fefer land their it,--wife, you do speed' the setself abling are design muging in on thy wit these queen,-lo =========================== Epoch 2650, loss: 1.3579 === Sample after epoch 2650 === This: O furfaced. Pompere art. Thine I merton. SICINIUS: You, so lording enemy, ratter to Moncen, She whils six your gracious calreaty nake mises the stonger I resent succossipt that I name. God, defe =========================== Epoch 2700, loss: 1.4649 === Sample after epoch 2700 === TH lay'd encurate Hamb: But lose heart, and postisnit of mastershin, Of our mootus. LORD STANLET: What that, he straights to Romeo the dissestance in this? What, is the lideness Ocl warrion! Wh Camili =========================== Epoch 2750, loss: 1.3758 === Sample after epoch 2750 === Turn. GLOUCESTER: Speak to him in the sword adied cursb, Ond your a proveds; will save you, One, if I must comes me she'er by burn'd of the pursue and for upon with her by dost been unousurence. ? tho =========================== Epoch 2800, loss: 1.4230 === Sample after epoch 2800 === Those wounds away. LUCIO: I partel Chirdly schasters by the wounded with the offigrow and the Boin'd attend me the mind play, noble King are rest. GLOUCESTER: Away, tell he your, been, I dare lameedn =========================== Epoch 2850, loss: 1.4183 === Sample after epoch 2850 === Thite the give; since there can thou shalt lovely dew, Thanks woman, your pale, Le't person from him. KING RICHARD III: Sir I smeets right, For it worship, you, 'twas had were something. KING RICHARD =========================== Epoch 2900, loss: 1.4308 === Sample after epoch 2900 === The rewer, What Henry that honourach for then, care to procliots: Aft will here Chord That I must known heart loys away in put so ophould herp in wruth's upon this herd? See he know'st noble disserbt, =========================== Epoch 2950, loss: 1.4675 === Sample after epoch 2950 === Thas thereof onmils. O, if lies in this liberty, And lay trouble. QUEEN ELIZABETH: The worpen your gracious consmalish and To thost! Comener did! 
all be them not he hapor, grant home to speak, To time =========================== Epoch 3000, loss: 1.3882 === Sample after epoch 3000 === TH lord, her honour, he hat yet people, Lord to this live, they save me sworn it should hence: These knew, and true arcean of him of it alive a gods had the That to back remauses and late, weary says A ===========================
Quite impressive: at first the model emits random characters, but within a few hundred steps it has learned plausible character combinations. It then starts producing word-like tokens (even if many are not real words), picks up patterns of capitalization and punctuation, and clearly learns the speaker-name-plus-colon dialogue format. By the end it handles punctuation fairly reliably and occasionally strings several sensible words together.
Plot Training Loss¶
plt.plot(losses)
plt.title('Training Loss')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()
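Note that val_ids was created earlier but never used during training. As a rough check for overfitting, here is a small sketch (reusing get_batch and compute_loss_fn from above) that estimates the mean loss on random batches from each split:
# Estimate the mean loss over a few random batches from a given split.
@torch.no_grad()
def estimate_loss(model, data_ids, n_batches=20, seq_len=128, batch_size=32):
    model.eval()
    total = 0.0
    for _ in range(n_batches):
        Xb, Yb = get_batch(data_ids, seq_len=seq_len, batch_size=batch_size)
        logits, _ = model(Xb.to(device))
        total += compute_loss_fn(logits, Yb.to(device)).item()
    return total / n_batches

print(f"train loss ~ {estimate_loss(model, train_ids):.4f}")
print(f"val   loss ~ {estimate_loss(model, val_ids):.4f}")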
4. Next-Token Probability Map¶
We'll pick a small context (e.g. 'The orchard is') and look at the predicted distribution for the next character at each position.
1) Function: top5_probs_across_phrase¶
This function:
- Iterates over a given phrase character by character,
- Feeds the substring to the model,
- Captures the softmax distribution for the next character,
- Extracts the top-5 indices and probabilities for each position,
- Returns a matrix of probabilities and the corresponding top-5 characters.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import HTML, display
def top5_probs_across_phrase(model, phrase, stoi, itos):
"""
For each character in 'phrase', compute the RNN model's next-token probability
distribution, pick top-5 tokens, store them, and return a matrix for heatmap, etc.
model: A char-level RNN (PyTorch) with a forward(...) that returns (logits, hidden).
phrase: The input string we feed character by character.
stoi: dict mapping chars -> integer IDs.
itos: dict mapping integer IDs -> chars.
Returns:
top5_probs: shape (T, 5), top-5 probabilities for each time step
top5_chars: list of lists of top-5 characters per time step
"""
model.eval()
distributions = []
top5_indices_list = []
hidden = None
with torch.no_grad():
        for t in range(len(phrase)):
            # feed only the t-th character; the carried hidden state already
            # encodes phrase[:t], so we don't re-feed the whole prefix each step
            x = torch.tensor([[stoi[phrase[t]]]], dtype=torch.long, device=device)
            logits, hidden = model(x, hidden=hidden)
# take the last output => next char distribution
last_logits = logits[0, -1, :] # shape (vocab_size,)
probs = torch.softmax(last_logits, dim=-1).cpu().numpy()
distributions.append(probs)
# top-5 indices
idx_sorted = np.argsort(-probs)
top5_idx = idx_sorted[:5]
top5_indices_list.append(top5_idx)
# Convert to arrays
distributions = np.array(distributions) # shape (T, vocab_size)
T = len(phrase)
top5_probs = np.zeros((T, 5))
top5_chars = []
for i in range(T):
t5i = top5_indices_list[i]
t5p = distributions[i][t5i]
top5_probs[i,:] = t5p
t5c = [itos[j] for j in t5i]
top5_chars.append(t5c)
return top5_probs, top5_chars
2) Using top5_probs_across_phrase + Heatmap¶
- We generate top-5 predictions for each step of a phrase.
- We print them in text form.
- We create a heatmap to visualize how probabilities evolve over each time step.
# Example usage
context_phrase = "The orchard is"
top5_matrix, top5_chs = top5_probs_across_phrase(model, context_phrase, stoi, itos)
# 1) Print top-5 tokens for each step
for i, ch in enumerate(context_phrase):
print(f"\nAfter reading '{context_phrase[:i+1]}' (last char '{ch}'):")
t5c = top5_chs[i]
t5p = top5_matrix[i]
# sort them by descending probability for readability
idx_desc = np.argsort(-t5p)
for rank in range(5):
char_ = t5c[idx_desc[rank]]
prob_ = t5p[idx_desc[rank]]
print(f" '{char_}' -> {prob_:.3f}")
# 2) Heatmap of top-5 probabilities (T x 5)
plt.figure(figsize=(8,6))
sns.heatmap(top5_matrix, annot=False, cmap='Blues')
plt.title(f"Top-5 Probability Heatmap for '{context_phrase}'")
plt.xlabel("Top-5 tokens (not necessarily the same at each step)")
plt.ylabel("Time-step in phrase")
plt.show()
After reading 'T' (last char 'T'): 'h' -> 0.437 'o' -> 0.184 'I' -> 0.073 'H' -> 0.055 'E' -> 0.049
After reading 'Th' (last char 'h'): 'y' -> 0.262 'e' -> 0.221 'i' -> 0.209 'o' -> 0.174 'u' -> 0.078
After reading 'The' (last char 'e'): 'n' -> 0.414 'r' -> 0.232 's' -> 0.181 ' ' -> 0.100 'R' -> 0.022
After reading 'The ' (last char ' '): 'r' -> 0.232 'C' -> 0.091 'm' -> 0.060 's' -> 0.060 't' -> 0.048
After reading 'The o' (last char 'o'): 'r' -> 0.194 't' -> 0.183 'n' -> 0.162 'w' -> 0.137 'f' -> 0.093
After reading 'The or' (last char 'r'): ' ' -> 0.865 'd' -> 0.021 ' ' -> 0.019 'e' -> 0.019 'i' -> 0.011
After reading 'The orc' (last char 'c'): 'e' -> 0.769 'h' -> 0.123 'o' -> 0.078 'u' -> 0.020 'i' -> 0.004
After reading 'The orch' (last char 'h'): 'e' -> 0.382 'u' -> 0.232 'a' -> 0.196 'i' -> 0.116 'o' -> 0.061
After reading 'The orcha' (last char 'a'): 'l' -> 0.416 'n' -> 0.231 'r' -> 0.105 's' -> 0.097 'm' -> 0.036
After reading 'The orchar' (last char 'r'): 'd' -> 0.721 'g' -> 0.053 'r' -> 0.050 'e' -> 0.045 's' -> 0.026
After reading 'The orchard' (last char 'd'): ',' -> 0.304 ' ' -> 0.149 'e' -> 0.129 '.' -> 0.074 ';' -> 0.063
After reading 'The orchard ' (last char ' '): 't' -> 0.171 'i' -> 0.113 'o' -> 0.079 'w' -> 0.077 'a' -> 0.077
After reading 'The orchard i' (last char 'i'): 'n' -> 0.483 's' -> 0.394 't' -> 0.082 'm' -> 0.014 'f' -> 0.013
After reading 'The orchard is' (last char 's'): ' ' -> 0.903 ',' -> 0.037 ' ' -> 0.019 '.' -> 0.017 ';' -> 0.006
We see that the network is sometimes very confident about the next character and sometimes not. It would be interesting to probe contexts where the continuation is essentially forced (such as the 'u' after 'q') to check for signs of overfitting.
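One way to quantify how "sure" the network is at each position is the entropy of its next-character distribution (low entropy means high confidence). A small sketch using the next_token_probs method defined earlier:
# Entropy of the predicted next-character distribution after each prefix.
phrase = "The orchard is"
for t in range(1, len(phrase) + 1):
    probs = model.next_token_probs(phrase[:t])
    entropy = -np.sum(probs * np.log2(probs + 1e-12))
    print(f"{phrase[:t]!r:>18}  entropy = {entropy:.2f} bits")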
3) Color Functions for Activation and Probability¶
The next functions map an activation value in [-1, 1] to a blue-to-green color scale, and a probability in [0, 1] to a white-to-red scale.
def color_val_for_activation(a_value):
"""
Map activation in [-1, 1] to a color gradient:
blue at -1, teal at 0, green at +1.
R=0..0, G=0..255, B=255..0
"""
av = max(-1.0, min(1.0, a_value))
scale = (av + 1.0) * 0.5 # 0 => -1, 1 => +1
g = int(scale * 255)
b = int((1.0 - scale) * 255)
return f"rgb(0,{g},{b})"
def color_val_for_prob(p):
"""
Map probability in [0,1] to a color from white(0) to red(1).
We'll do purely red: white => p=0 => rgb(255,255,255)
red => p=1 => rgb(255,0,0)
"""
p_clamped = max(0.0, min(1.0, p))
r = 255
g = int(255 - 255*p_clamped)
b = int(255 - 255*p_clamped)
return f"rgb({r},{g},{b})"
def visualize_firing_and_guesses(
text,
hidden_vals, # shape (T, hidden_size)
top5_ix, # shape (T, 5)
top5_probs, # shape (T, 5)
encoded_input, # shape (T,)
neuron_idx = 0,
itos = None
):
"""
text: original input string
hidden_vals: (T, hidden_size), hidden states for each time step
top5_ix: (T, 5)
top5_probs: (T, 5)
neuron_idx: which dimension of hidden state to visualize
itos: index->char mapping
"""
T = len(text)
html = []
html.append("<div style='font-family:monospace;'>")
html.append("<table>")
# Row 1: input characters colored by hidden state dimension
html.append("<tr>")
for t in range(T):
char_ = text[t]
a_val = hidden_vals[t, neuron_idx] # e.g. LSTM hidden dimension
ccol = color_val_for_activation(a_val)
html.append(f"<td style='background-color:{ccol}; padding:4px;'>{char_}</td>")
html.append("</tr>")
# Row 2: top-5 guesses
html.append("<tr>")
for t in range(T):
guesses_ix = top5_ix[t]
guesses_pb = top5_probs[t]
cell_lines = []
for i in range(5):
gix = guesses_ix[i]
p_ = guesses_pb[i]
ch_ = itos[gix] if itos else f"{gix}"
ccol = color_val_for_prob(p_)
cell_lines.append(f"<span style='background-color:{ccol}'>{ch_} {p_:.2f}</span>")
cell_html = "<br/>".join(cell_lines)
html.append(f"<td style='vertical-align:top; padding:4px;'>{cell_html}</td>")
html.append("</tr>")
html.append("</table>")
html.append("</div>")
return "".join(html)
# We'll define a helper that obtains hidden states + top-5 guesses
# for each step of a phrase. It's similar to 'top5_probs_across_phrase',
# but it also stores the hidden state at each step: `get_firing_and_top5(...)`.
def get_firing_and_top5(model, text, stoi, itos):
"""
Returns:
hidden_vals: shape (T, hidden_size)
top5_ix: (T,5)
top5_probs: (T,5)
"""
model.eval()
hidden_vals = []
top5_ix = []
top5_pb = []
hidden = None
with torch.no_grad():
        for t in range(len(text)):
            # feed only the t-th character; the carried hidden state already encodes text[:t]
            x = torch.tensor([[stoi[text[t]]]], dtype=torch.long, device=device)
            logits, hidden_state = model(x, hidden=hidden)
            hidden = hidden_state  # LSTM => tuple (h, c)
# Extract hidden vector => shape (1,1,hidden_size) => pick [0,0,:]
# or for a GRU => shape (1,1,hidden_size)
# We'll assume LSTM for example:
h_vec = hidden_state[0][0,0,:].cpu().numpy() # shape (hidden_size,)
hidden_vals.append(h_vec)
# Next char dist
last_logits = logits[0,-1,:]
probs = torch.softmax(last_logits, dim=-1).cpu().numpy()
# top-5
idx_sorted = np.argsort(-probs)
t5i = idx_sorted[:5]
t5p = probs[t5i]
top5_ix.append(t5i)
top5_pb.append(t5p)
hidden_vals = np.stack(hidden_vals, axis=0) # (T, hidden_size)
top5_ix = np.stack(top5_ix, axis=0) # (T, 5)
top5_pb = np.stack(top5_pb, axis=0) # (T, 5)
return hidden_vals, top5_ix, top5_pb
# Now let's do final usage:
my_text = "The orchard is"
hvals, t5ix, t5pb = get_firing_and_top5(model, my_text, stoi, itos)
html_str = visualize_firing_and_guesses(
text=my_text,
hidden_vals=hvals,
top5_ix=t5ix,
top5_probs=t5pb,
encoded_input=[stoi[ch] for ch in my_text],
neuron_idx=10, # e.g. dimension #10
itos=itos
)
display(HTML(html_str))
[Rendered HTML output: the characters of 'The orchard is', each cell shaded by the activation of hidden unit 10, with the top-5 next-character guesses and probabilities (matching the values printed above) listed beneath each position.]
So we can see real differences between predictions: for instance, the model is quite sure about 'd' after 'The orchar', but much less sure about 'h' after 'The orc'.
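We can double-check this with the next_token_probs method defined on the model (a quick sketch; the exact numbers depend on the trained weights):
# Compare the model's top guesses after two prefixes of different ambiguity.
for ctx in ("The orchar", "The orc"):
    probs = model.next_token_probs(ctx)
    top = np.argsort(-probs)[:3]
    ranked = ", ".join(f"{itos[int(i)]!r}: {probs[i]:.2f}" for i in top)
    print(f"after {ctx!r}: {ranked}")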
# Another example with a different phrase:
my_text = "Have you heard"
hvals, t5ix, t5pb = get_firing_and_top5(model, my_text, stoi, itos)
html_str = visualize_firing_and_guesses(
text=my_text,
hidden_vals=hvals,
top5_ix=t5ix,
top5_probs=t5pb,
encoded_input=[stoi[ch] for ch in my_text],
neuron_idx=10, # e.g. dimension #10
itos=itos
)
display(HTML(html_str))
[Rendered HTML output: the characters of 'Have you heard', each shaded by the activation of hidden unit 10, with the top-5 next-character guesses beneath each position; unlabelled guesses are whitespace characters:]
H: A 0.30  e 0.27  a 0.13  i 0.06  E 0.05
a: v 0.48  t 0.25  n 0.08  d 0.07  s 0.05
v: e 0.88  i 0.12  y 0.00  o 0.00  a 0.00
e: 0.85  s 0.04  0.03  , 0.02  n 0.01
(space): t 0.24  m 0.13  y 0.10  a 0.07  h 0.05
y: o 0.85  e 0.12  i 0.03  a 0.00  u 0.00
o: u 0.99  n 0.00  w 0.00  r 0.00  t 0.00
u: r 0.46  0.32  , 0.09  . 0.04  0.03
(space): t 0.26  a 0.10  w 0.07  b 0.07  n 0.07
h: a 0.51  e 0.23  i 0.16  o 0.09  u 0.01
e: a 0.40  r 0.32  n 0.10  0.07  l 0.05
a: r 0.71  v 0.15  d 0.10  t 0.02  l 0.01
r: 0.53  s 0.10  d 0.10  t 0.09  , 0.05
d: 0.50  , 0.18  0.09  . 0.05  s 0.03
Here, by contrast, the network is generally more confident about the next character. Its confidence may be correlated with character-sequence frequencies in the training set; this would be worth checking.
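As a rough first look at that hypothesis, we can compare the model's confidence against simple bigram statistics computed from all_text (a sketch only; bigram counts are a crude proxy for the longer context the LSTM conditions on):
from collections import Counter

# Empirical next-character frequencies after a given character in the corpus.
def bigram_followers(text, ch, topk=5):
    followers = Counter(text[i + 1] for i in range(len(text) - 1) if text[i] == ch)
    total = sum(followers.values())
    return [(c, round(n / total, 3)) for c, n in followers.most_common(topk)]

print("corpus followers of 'q':", bigram_followers(all_text, 'q'))
print("corpus followers of 't':", bigram_followers(all_text, 't'))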
5. Activation Map¶
We’ll pick a short context, 'Thou art', and retrieve the RNN hidden states to see which dimensions light up, using a simple matplotlib heatmap.
acts = model.get_activations("Thou art") # shape (seq_len, hidden_size)
print("Activations shape:", acts.shape)
plt.figure(figsize=(10,3))
plt.imshow(acts.T, aspect='auto', cmap='viridis')
plt.colorbar()
plt.title("RNN Hidden State Activations (each row = one hidden dim)")
plt.xlabel("Time Step (for each char in 'Thou art')")
plt.ylabel("Hidden Dimension")
plt.show()
Activations shape: (8, 128)
From the above heatmap, you can see that some hidden dimensions are strongly active (bright) while others stay quiet (dark) at each time step, which gives a rough sense of which units respond to which characters, reminiscent of the visualizations in Karpathy’s blog post.
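To decide which individual units deserve a closer look (in the spirit of the "interpretable neurons" in Karpathy's post), here is a small sketch that ranks hidden dimensions by how much their activation varies over a longer snippet and plots the most variable one:
# Rank hidden units by activation variance across a longer snippet;
# high-variance units are the most interesting ones to inspect individually.
snippet = all_text[:500]
acts_long = model.get_activations(snippet)   # (seq_len, hidden_size)
variances = acts_long.var(axis=0)
top_units = np.argsort(-variances)[:5]
print("most variable hidden units:", top_units, variances[top_units].round(3))

plt.figure(figsize=(10, 2))
plt.plot(acts_long[:, top_units[0]])
plt.title(f"Activation of hidden unit {top_units[0]} over the first {len(snippet)} chars")
plt.xlabel("Character position")
plt.show()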
Conclusion¶
- We loaded tiny_shakespeare and built a small character-level LSTM in PyTorch.
- We trained for 3000 steps, showing sample text every 50 steps.
- We displayed next-token probability distributions for chosen contexts.
- We plotted the RNN hidden state activation map to see which neurons activate.
Following Karpathy’s blog post, even a small RNN can learn to generate Shakespeare-like text given enough training and reasonable hyperparameters, and the samples above show the output growing steadily more coherent as training progresses.
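As a closing experiment, here is a sketch of temperature-controlled sampling (not part of the generate() method defined above): temperatures below 1.0 make the samples more conservative, while higher temperatures make them more surprising.
# Temperature-controlled sampling: dividing the logits by a temperature < 1.0
# sharpens the distribution, while > 1.0 flattens it.
@torch.no_grad()
def generate_with_temperature(model, start_char='T', max_new_tokens=200, temperature=0.8):
    model.eval()
    hidden = None
    x = torch.tensor([[stoi[start_char]]], dtype=torch.long, device=device)
    out = [start_char]
    for _ in range(max_new_tokens):
        logits, hidden = model(x, hidden=hidden)
        probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
        x = torch.multinomial(probs, num_samples=1)
        out.append(itos[x.item()])
    return ''.join(out)

for temp in (0.5, 1.0, 1.3):
    print(f"\n--- temperature {temp} ---\n{generate_with_temperature(model, temperature=temp)}")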