You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

1878 lines
101 KiB

{
"cells": [
{
"cell_type": "markdown",
"id": "622dfcd6",
"metadata": {},
"source": [
"# Group 09 - Automatic Sentence Completion for PubMed"
]
},
{
"cell_type": "markdown",
"id": "5e4cec3c",
"metadata": {},
"source": [
"### Authors:\n",
"- Constantin Fuerst\t\n",
"- Leonard Starke"
]
},
{
"cell_type": "markdown",
"id": "806cfb27",
"metadata": {},
"source": [
"### link to \"Attention is All You Need\" paper describing transformer models"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe862072",
"metadata": {},
"outputs": [],
"source": [
"# https://arxiv.org/pdf/1706.03762.pdf"
]
},
{
"cell_type": "markdown",
"id": "fa161b1b",
"metadata": {},
"source": [
"### load query data from text file "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e1912a79",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2023-01-18 15:48:24-- https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download\n",
"Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'\n",
"Resolving cloud.constantin-fuerst.com (cloud.constantin-fuerst.com)... 95.91.21.14\n",
"Connecting to cloud.constantin-fuerst.com (cloud.constantin-fuerst.com)|95.91.21.14|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1100551 (1.0M) [text/plain]\n",
"Saving to: ‘pubmed-query.txt’\n",
"\n",
"pubmed-query.txt 100%[===================>] 1.05M 1.61MB/s in 0.7s \n",
"\n",
"2023-01-18 15:48:25 (1.61 MB/s) - ‘pubmed-query.txt’ saved [1100551/1100551]\n",
"\n"
]
}
],
"source": [
"!wget https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download -O pubmed-query.txt"
]
},
{
"cell_type": "markdown",
"id": "da068411",
"metadata": {},
"source": [
"### import modules used for parsing query data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c10bc5a8",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" from Bio import Medline\n",
"except ImportError: \n",
" !pip install biopython\n",
" from Bio import Medline"
]
},
{
"cell_type": "markdown",
"id": "7bf15c30",
"metadata": {},
"source": [
"### define function for loading the papers from PubMed database"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "adfb256a",
"metadata": {},
"outputs": [],
"source": [
"def getPapers(filename):\n",
" pubmed_query = open(filename, encoding='utf-8')\n",
" records = Medline.parse(pubmed_query)\n",
" return list(records)"
]
},
{
"cell_type": "markdown",
"id": "46bc6298",
"metadata": {},
"source": [
"### Verify that it's working"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "00481ec9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Got 150000 records from the query text file\n"
]
}
],
"source": [
"max_records = 150000\n",
"records = getPapers(\"pubmed-query.txt\")\n",
"records = records[:min(max_records, len(records))]\n",
"print(f\"Got {len(records)} records from the query text file\")"
]
},
{
"cell_type": "markdown",
"id": "b67747c6",
"metadata": {},
"source": [
"### Now extract abstracts from records"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dcf5c217",
"metadata": {},
"outputs": [],
"source": [
"r_abstracts = []\n",
"for r in records:\n",
" if not (r.get('AB') is None):\n",
" r_abstracts.append(r['AB'])"
]
},
{
"cell_type": "markdown",
"id": "e309f6fe",
"metadata": {},
"source": [
"### Now import torch modules needed to load the data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c3199444",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hein/.local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"try:\n",
" import torch\n",
" from torch.utils.data import Dataset \n",
" from torchtext.data import get_tokenizer\n",
"except ImportError:\n",
" !pip --default-timeout=1000 install torch\n",
" !pip --default-timeout=1000 install torchtext\n",
" import torch\n",
" from torch.utils.data import Dataset \n",
" from torchtext.data import get_tokenizer"
]
},
{
"cell_type": "markdown",
"id": "5b4007e8",
"metadata": {},
"source": [
"### Import numpy"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "daca9db6",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import numpy as np\n",
"except ImportError:\n",
" !pip install numpy\n",
" import numpy as np\n"
]
},
{
"cell_type": "markdown",
"id": "683ed2fc",
"metadata": {},
"source": [
"### import math module"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8d2312db",
"metadata": {},
"outputs": [],
"source": [
"import math"
]
},
{
"cell_type": "markdown",
"id": "4df1e449",
"metadata": {},
"source": [
"### define token iterators"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3f23404d",
"metadata": {},
"outputs": [],
"source": [
"train_size = math.floor(len(r_abstracts) * 0.75)\n",
"val_size = math.floor(len(r_abstracts) * 0.125)\n",
"test_size = math.floor(len(r_abstracts) * 0.125)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8a128d3c",
"metadata": {},
"outputs": [],
"source": [
"def train_abstract_iter():\n",
" for abstract in r_abstracts[:train_size]:\n",
" yield abstract"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "97e89986",
"metadata": {},
"outputs": [],
"source": [
"def val_abstract_iter():\n",
" for abstract in r_abstracts[train_size:(train_size + val_size)]:\n",
" yield abstract"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0d6e89c4",
"metadata": {},
"outputs": [],
"source": [
"def test_abstract_iter():\n",
" for abstract in r_abstracts[(train_size + val_size):(train_size + val_size + test_size)]:\n",
" yield abstract"
]
},
{
"cell_type": "markdown",
"id": "e5e9c5a2",
"metadata": {},
"source": [
"### define Tokenize function"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "0bdbc40a",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = get_tokenizer(\"basic_english\")\n",
"def tokenize_abstract_iter():\n",
" for abstract in r_abstracts:\n",
" yield tokenizer(abstract)"
]
},
{
"cell_type": "markdown",
"id": "37da40bb",
"metadata": {},
"source": [
"### Map every word to an id to store inside torch tensor"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a438ab1f",
"metadata": {},
"outputs": [],
"source": [
"from torchtext.vocab import build_vocab_from_iterator\n",
"token_generator = tokenize_abstract_iter()\n",
"vocab = build_vocab_from_iterator(token_generator, specials=['<unk>'])\n",
"vocab.set_default_index(vocab['<unk>'])\n"
]
},
{
"cell_type": "markdown",
"id": "221bdc48",
"metadata": {},
"source": [
"### now convert to tensor\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "0e5bc361",
"metadata": {},
"outputs": [],
"source": [
"def data_process(tokens_iter):\n",
" \"\"\"Converts raw text into a flat Tensor.\"\"\"\n",
" data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in tokens_iter]\n",
" return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "dfd7400d",
"metadata": {},
"outputs": [],
"source": [
"train_generator = train_abstract_iter()\n",
"val_generator = val_abstract_iter()\n",
"test_generator = test_abstract_iter()\n",
"train_data = data_process(train_generator)\n",
"val_data = data_process(val_generator)\n",
"test_data = data_process(test_generator)"
]
},
{
"cell_type": "markdown",
"id": "c49a2734",
"metadata": {},
"source": [
"### check gpu"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "c155ee31",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "79b2d248",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device"
]
},
{
"cell_type": "markdown",
"id": "2150ba71",
"metadata": {},
"source": [
"### define model"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a33d722f",
"metadata": {},
"outputs": [],
"source": [
"from typing import Tuple\n",
"\n",
"from torch import nn, Tensor\n",
"import torch.nn.functional as F\n",
"from torch.nn import TransformerEncoder, TransformerEncoderLayer\n",
"from torch.utils.data import dataset\n",
"\n",
"class TransformerModel(nn.Module):\n",
"\n",
" def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,\n",
" nlayers: int, dropout: float = 0.5):\n",
" super().__init__()\n",
" self.model_type = 'Transformer'\n",
" self.pos_encoder = PositionalEncoding(d_model, dropout)\n",
" encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)\n",
" self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n",
" self.encoder = nn.Embedding(ntoken, d_model)\n",
" self.d_model = d_model\n",
" self.decoder = nn.Linear(d_model, ntoken)\n",
"\n",
" self.init_weights()\n",
"\n",
" def init_weights(self) -> None:\n",
" initrange = 0.1\n",
" self.encoder.weight.data.uniform_(-initrange, initrange)\n",
" self.decoder.bias.data.zero_()\n",
" self.decoder.weight.data.uniform_(-initrange, initrange)\n",
"\n",
" def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:\n",
" \"\"\"\n",
" Args:\n",
" src: Tensor, shape [seq_len, batch_size]\n",
" src_mask: Tensor, shape [seq_len, seq_len]\n",
"\n",
" Returns:\n",
" output Tensor of shape [seq_len, batch_size, ntoken]\n",
" \"\"\"\n",
" src = self.encoder(src) * math.sqrt(self.d_model)\n",
" src = self.pos_encoder(src)\n",
" output = self.transformer_encoder(src, src_mask)\n",
" output = self.decoder(output)\n",
" return output\n",
"\n",
"\n",
"def generate_square_subsequent_mask(sz: int) -> Tensor:\n",
" \"\"\"Generates an upper-triangular matrix of -inf, with zeros on diag.\"\"\"\n",
" return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)"
]
},
{
"cell_type": "markdown",
"id": "23268efe",
"metadata": {},
"source": [
"### define pos encoder"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "c2f6d33b",
"metadata": {},
"outputs": [],
"source": [
"class PositionalEncoding(nn.Module):\n",
"\n",
" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):\n",
" super().__init__()\n",
" self.dropout = nn.Dropout(p=dropout)\n",
"\n",
" position = torch.arange(max_len).unsqueeze(1)\n",
" div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
" pe = torch.zeros(max_len, 1, d_model)\n",
" pe[:, 0, 0::2] = torch.sin(position * div_term)\n",
" pe[:, 0, 1::2] = torch.cos(position * div_term)\n",
" self.register_buffer('pe', pe)\n",
"\n",
" def forward(self, x: Tensor) -> Tensor:\n",
" \"\"\"\n",
" Args:\n",
" x: Tensor, shape [seq_len, batch_size, embedding_dim]\n",
" \"\"\"\n",
" x = x + self.pe[:x.size(0)]\n",
" return self.dropout(x)\n"
]
},
{
"cell_type": "markdown",
"id": "306352f5",
"metadata": {},
"source": [
"### define function to create batches of data and create batches"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "9e184841",
"metadata": {},
"outputs": [],
"source": [
"def batchify(data: Tensor, bsz: int) -> Tensor:\n",
" \"\"\"Divides the data into bsz separate sequences, removing extra elements\n",
" that wouldn't cleanly fit.\n",
"\n",
" Args:\n",
" data: Tensor, shape [N]\n",
" bsz: int, batch size\n",
"\n",
" Returns:\n",
" Tensor of shape [N // bsz, bsz]\n",
" \"\"\"\n",
" seq_len = data.size(0) // bsz\n",
" data = data[:seq_len * bsz]\n",
" data = data.view(bsz, seq_len).t().contiguous()\n",
" return data.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "a4def1ac",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 20\n",
"eval_batch_size = 10\n",
"train_data = batchify(train_data, batch_size) # shape [seq_len, batch_size]\n",
"val_data = batchify(val_data, eval_batch_size)\n",
"test_data = batchify(test_data, eval_batch_size)"
]
},
{
"cell_type": "markdown",
"id": "c658cb42",
"metadata": {},
"source": [
"### define function to get batch"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "4ab5b8fd",
"metadata": {},
"outputs": [],
"source": [
"bptt = 35\n",
"def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:\n",
" \"\"\"\n",
" Args:\n",
" source: Tensor, shape [full_seq_len, batch_size]\n",
" i: int\n",
"\n",
" Returns:\n",
" tuple (data, target), where data has shape [seq_len, batch_size] and\n",
" target has shape [seq_len * batch_size]\n",
" \"\"\"\n",
" seq_len = min(bptt, len(source) - 1 - i)\n",
" data = source[i:i+seq_len]\n",
" target = source[i+1:i+1+seq_len].reshape(-1)\n",
" return data, target"
]
},
{
"cell_type": "markdown",
"id": "d6392484",
"metadata": {},
"source": [
"### define parameters and init model"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "c53764da",
"metadata": {},
"outputs": [],
"source": [
"ntokens = len(vocab) # size of vocabulary\n",
"emsize = 200 # embedding dimension\n",
"d_hid = 200 # dimension of the feedforward network model in nn.TransformerEncoder\n",
"nlayers = 2 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder\n",
"nhead = 2 # number of heads in nn.MultiheadAttention\n",
"dropout = 0.2 # dropout probability\n",
"model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)"
]
},
{
"cell_type": "markdown",
"id": "7fb67d72",
"metadata": {},
"source": [
"### init optimizer, loss, scheduler etc."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "ddaa1d64",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import time\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"lr = 5.0 # learning rate\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)"
]
},
{
"cell_type": "markdown",
"id": "dda19446",
"metadata": {},
"source": [
"### define train function"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "50ab3fb6",
"metadata": {},
"outputs": [],
"source": [
"def train(model: nn.Module) -> None:\n",
" model.train() # turn on train mode\n",
" total_loss = 0.\n",
" log_interval = 200\n",
" start_time = time.time()\n",
" src_mask = generate_square_subsequent_mask(bptt).to(device)\n",
"\n",
" num_batches = len(train_data) // bptt\n",
" for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n",
" data, targets = get_batch(train_data, i)\n",
" seq_len = data.size(0)\n",
" if seq_len != bptt: # only on last batch\n",
" src_mask = src_mask[:seq_len, :seq_len]\n",
" output = model(data, src_mask)\n",
" loss = criterion(output.view(-1, ntokens), targets)\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item()\n",
" if batch % log_interval == 0 and batch > 0:\n",
" lr = scheduler.get_last_lr()[0]\n",
" ms_per_batch = (time.time() - start_time) * 1000 / log_interval\n",
" cur_loss = total_loss / log_interval\n",
" ppl = math.exp(cur_loss)\n",
" print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '\n",
" f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n",
" f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n",
" total_loss = 0\n",
" start_time = time.time()"
]
},
{
"cell_type": "markdown",
"id": "9756c092",
"metadata": {},
"source": [
"### define evaluate function"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "3d179bb0",
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n",
" model.eval() # turn on evaluation mode\n",
" total_loss = 0.\n",
" src_mask = generate_square_subsequent_mask(bptt).to(device)\n",
" with torch.no_grad():\n",
" for i in range(0, eval_data.size(0) - 1, bptt):\n",
" data, targets = get_batch(eval_data, i)\n",
" seq_len = data.size(0)\n",
" if seq_len != bptt:\n",
" src_mask = src_mask[:seq_len, :seq_len]\n",
" output = model(data, src_mask)\n",
" output_flat = output.view(-1, ntokens)\n",
" total_loss += seq_len * criterion(output_flat, targets).item()\n",
" return total_loss / (len(eval_data) - 1)"
]
},
{
"cell_type": "markdown",
"id": "5a959f09",
"metadata": {},
"source": [
"### now we can start training the model while saving best one"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "09c4d4ce",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 1 | 200/13484 batches | lr 5.00 | ms/batch 116.24 | loss 9.27 | ppl 10651.55\n",
"| epoch 1 | 400/13484 batches | lr 5.00 | ms/batch 114.02 | loss 7.49 | ppl 1787.62\n",
"| epoch 1 | 600/13484 batches | lr 5.00 | ms/batch 114.33 | loss 6.83 | ppl 923.44\n",
"| epoch 1 | 800/13484 batches | lr 5.00 | ms/batch 114.54 | loss 6.54 | ppl 693.98\n",
"| epoch 1 | 1000/13484 batches | lr 5.00 | ms/batch 114.73 | loss 6.33 | ppl 563.29\n",
"| epoch 1 | 1200/13484 batches | lr 5.00 | ms/batch 114.85 | loss 6.18 | ppl 485.05\n",
"| epoch 1 | 1400/13484 batches | lr 5.00 | ms/batch 114.91 | loss 6.09 | ppl 440.69\n",
"| epoch 1 | 1600/13484 batches | lr 5.00 | ms/batch 115.00 | loss 6.06 | ppl 428.38\n",
"| epoch 1 | 1800/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.98 | ppl 397.07\n",
"| epoch 1 | 2000/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.91 | ppl 369.13\n",
"| epoch 1 | 2200/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.89 | ppl 360.14\n",
"| epoch 1 | 2400/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.83 | ppl 341.10\n",
"| epoch 1 | 2600/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.78 | ppl 322.33\n",
"| epoch 1 | 2800/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.80 | ppl 329.27\n",
"| epoch 1 | 3000/13484 batches | lr 5.00 | ms/batch 115.12 | loss 5.77 | ppl 321.64\n",
"| epoch 1 | 3200/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.71 | ppl 303.37\n",
"| epoch 1 | 3400/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.74 | ppl 311.04\n",
"| epoch 1 | 3600/13484 batches | lr 5.00 | ms/batch 115.15 | loss 5.70 | ppl 299.44\n",
"| epoch 1 | 3800/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.68 | ppl 292.67\n",
"| epoch 1 | 4000/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.59 | ppl 268.70\n",
"| epoch 1 | 4200/13484 batches | lr 5.00 | ms/batch 115.19 | loss 5.62 | ppl 275.23\n",
"| epoch 1 | 4400/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.63 | ppl 277.51\n",
"| epoch 1 | 4600/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.66 | ppl 286.99\n",
"| epoch 1 | 4800/13484 batches | lr 5.00 | ms/batch 115.30 | loss 5.62 | ppl 276.08\n",
"| epoch 1 | 5000/13484 batches | lr 5.00 | ms/batch 115.15 | loss 5.61 | ppl 272.68\n",
"| epoch 1 | 5200/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.59 | ppl 268.83\n",
"| epoch 1 | 5400/13484 batches | lr 5.00 | ms/batch 115.29 | loss 5.55 | ppl 257.80\n",
"| epoch 1 | 5600/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.57 | ppl 261.32\n",
"| epoch 1 | 5800/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.55 | ppl 257.06\n",
"| epoch 1 | 6000/13484 batches | lr 5.00 | ms/batch 115.26 | loss 5.56 | ppl 259.08\n",
"| epoch 1 | 6200/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.57 | ppl 262.89\n",
"| epoch 1 | 6400/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.54 | ppl 254.66\n",
"| epoch 1 | 6600/13484 batches | lr 5.00 | ms/batch 115.27 | loss 5.57 | ppl 263.01\n",
"| epoch 1 | 6800/13484 batches | lr 5.00 | ms/batch 115.21 | loss 5.51 | ppl 246.13\n",
"| epoch 1 | 7000/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.57 | ppl 261.50\n",
"| epoch 1 | 7200/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.51 | ppl 247.48\n",
"| epoch 1 | 7400/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.50 | ppl 245.45\n",
"| epoch 1 | 7600/13484 batches | lr 5.00 | ms/batch 115.26 | loss 5.51 | ppl 247.79\n",
"| epoch 1 | 7800/13484 batches | lr 5.00 | ms/batch 115.27 | loss 5.50 | ppl 245.74\n",
"| epoch 1 | 8000/13484 batches | lr 5.00 | ms/batch 115.33 | loss 5.48 | ppl 240.49\n",
"| epoch 1 | 8200/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.48 | ppl 238.87\n",
"| epoch 1 | 8400/13484 batches | lr 5.00 | ms/batch 115.27 | loss 5.49 | ppl 241.45\n",
"| epoch 1 | 8600/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.47 | ppl 236.88\n",
"| epoch 1 | 8800/13484 batches | lr 5.00 | ms/batch 115.28 | loss 5.47 | ppl 236.31\n",
"| epoch 1 | 9000/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.48 | ppl 240.63\n",
"| epoch 1 | 9200/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.48 | ppl 239.53\n",
"| epoch 1 | 9400/13484 batches | lr 5.00 | ms/batch 115.29 | loss 5.48 | ppl 238.75\n",
"| epoch 1 | 9600/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.43 | ppl 229.14\n",
"| epoch 1 | 9800/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.42 | ppl 226.49\n",
"| epoch 1 | 10000/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.47 | ppl 236.79\n",
"| epoch 1 | 10200/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.41 | ppl 223.98\n",
"| epoch 1 | 10400/13484 batches | lr 5.00 | ms/batch 115.16 | loss 5.39 | ppl 219.63\n",
"| epoch 1 | 10600/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.42 | ppl 225.37\n",
"| epoch 1 | 10800/13484 batches | lr 5.00 | ms/batch 115.30 | loss 5.45 | ppl 232.44\n",
"| epoch 1 | 11000/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.45 | ppl 232.12\n",
"| epoch 1 | 11200/13484 batches | lr 5.00 | ms/batch 115.21 | loss 5.43 | ppl 228.71\n",
"| epoch 1 | 11400/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.38 | ppl 216.73\n",
"| epoch 1 | 11600/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.41 | ppl 222.68\n",
"| epoch 1 | 11800/13484 batches | lr 5.00 | ms/batch 115.28 | loss 5.39 | ppl 218.39\n",
"| epoch 1 | 12000/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.44 | ppl 229.94\n",
"| epoch 1 | 12200/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.36 | ppl 213.26\n",
"| epoch 1 | 12400/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.38 | ppl 217.41\n",
"| epoch 1 | 12600/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.40 | ppl 222.35\n",
"| epoch 1 | 12800/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.41 | ppl 224.63\n",
"| epoch 1 | 13000/13484 batches | lr 5.00 | ms/batch 115.29 | loss 5.40 | ppl 220.79\n",
"| epoch 1 | 13200/13484 batches | lr 5.00 | ms/batch 115.16 | loss 5.41 | ppl 223.58\n",
"| epoch 1 | 13400/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.42 | ppl 225.49\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 1 | time: 1625.43s | valid loss 5.35 | valid ppl 210.93\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 2 | 200/13484 batches | lr 4.75 | ms/batch 115.84 | loss 5.44 | ppl 229.80\n",
"| epoch 2 | 400/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.38 | ppl 216.74\n",
"| epoch 2 | 600/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.35 | ppl 211.15\n",
"| epoch 2 | 800/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.37 | ppl 215.74\n",
"| epoch 2 | 1000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.35 | ppl 210.96\n",
"| epoch 2 | 1200/13484 batches | lr 4.75 | ms/batch 115.17 | loss 5.33 | ppl 207.12\n",
"| epoch 2 | 1400/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.34 | ppl 208.70\n",
"| epoch 2 | 1600/13484 batches | lr 4.75 | ms/batch 115.14 | loss 5.36 | ppl 212.80\n",
"| epoch 2 | 1800/13484 batches | lr 4.75 | ms/batch 115.11 | loss 5.35 | ppl 209.96\n",
"| epoch 2 | 2000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.32 | ppl 203.54\n",
"| epoch 2 | 2200/13484 batches | lr 4.75 | ms/batch 115.15 | loss 5.33 | ppl 205.82\n",
"| epoch 2 | 2400/13484 batches | lr 4.75 | ms/batch 115.23 | loss 5.34 | ppl 208.95\n",
"| epoch 2 | 2600/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.29 | ppl 199.16\n",
"| epoch 2 | 2800/13484 batches | lr 4.75 | ms/batch 115.16 | loss 5.34 | ppl 208.19\n",
"| epoch 2 | 3000/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.33 | ppl 205.88\n",
"| epoch 2 | 3200/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.29 | ppl 198.11\n",
"| epoch 2 | 3400/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.37 | ppl 214.29\n",
"| epoch 2 | 3600/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.31 | ppl 202.72\n",
"| epoch 2 | 3800/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.32 | ppl 203.84\n",
"| epoch 2 | 4000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.24 | ppl 189.14\n",
"| epoch 2 | 4200/13484 batches | lr 4.75 | ms/batch 115.15 | loss 5.28 | ppl 196.95\n",
"| epoch 2 | 4400/13484 batches | lr 4.75 | ms/batch 115.17 | loss 5.29 | ppl 198.84\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 2 | 4600/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.35 | ppl 210.15\n",
"| epoch 2 | 4800/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.32 | ppl 204.37\n",
"| epoch 2 | 5000/13484 batches | lr 4.75 | ms/batch 115.29 | loss 5.33 | ppl 205.42\n",
"| epoch 2 | 5200/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.31 | ppl 201.44\n",
"| epoch 2 | 5400/13484 batches | lr 4.75 | ms/batch 115.23 | loss 5.30 | ppl 200.48\n",
"| epoch 2 | 5600/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.29 | ppl 197.76\n",
"| epoch 2 | 5800/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.34 | ppl 207.65\n",
"| epoch 2 | 6000/13484 batches | lr 4.75 | ms/batch 115.11 | loss 5.32 | ppl 204.89\n",
"| epoch 2 | 6200/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.34 | ppl 209.24\n",
"| epoch 2 | 6400/13484 batches | lr 4.75 | ms/batch 115.14 | loss 5.31 | ppl 201.48\n",
"| epoch 2 | 6600/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.36 | ppl 212.87\n",
"| epoch 2 | 6800/13484 batches | lr 4.75 | ms/batch 115.13 | loss 5.29 | ppl 198.41\n",
"| epoch 2 | 7000/13484 batches | lr 4.75 | ms/batch 115.16 | loss 5.35 | ppl 211.39\n",
"| epoch 2 | 7200/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.30 | ppl 199.94\n",
"| epoch 2 | 7400/13484 batches | lr 4.75 | ms/batch 115.11 | loss 5.30 | ppl 200.81\n",
"| epoch 2 | 7600/13484 batches | lr 4.75 | ms/batch 115.31 | loss 5.35 | ppl 211.20\n",
"| epoch 2 | 7800/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.31 | ppl 201.93\n",
"| epoch 2 | 8000/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.29 | ppl 198.24\n",
"| epoch 2 | 8200/13484 batches | lr 4.75 | ms/batch 115.14 | loss 5.27 | ppl 194.75\n",
"| epoch 2 | 8400/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.29 | ppl 198.48\n",
"| epoch 2 | 8600/13484 batches | lr 4.75 | ms/batch 115.13 | loss 5.29 | ppl 198.11\n",
"| epoch 2 | 8800/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.34 | ppl 207.62\n",
"| epoch 2 | 9000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.33 | ppl 205.55\n",
"| epoch 2 | 9200/13484 batches | lr 4.75 | ms/batch 115.27 | loss 5.33 | ppl 206.24\n",
"| epoch 2 | 9400/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.31 | ppl 201.81\n",
"| epoch 2 | 9600/13484 batches | lr 4.75 | ms/batch 115.25 | loss 5.29 | ppl 198.63\n",
"| epoch 2 | 9800/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.26 | ppl 192.87\n",
"| epoch 2 | 10000/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.30 | ppl 199.77\n",
"| epoch 2 | 10200/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.25 | ppl 191.30\n",
"| epoch 2 | 10400/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.22 | ppl 184.78\n",
"| epoch 2 | 10600/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.27 | ppl 194.07\n",
"| epoch 2 | 10800/13484 batches | lr 4.75 | ms/batch 115.23 | loss 5.30 | ppl 200.53\n",
"| epoch 2 | 11000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.29 | ppl 198.68\n",
"| epoch 2 | 11200/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.28 | ppl 196.43\n",
"| epoch 2 | 11400/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.23 | ppl 186.61\n",
"| epoch 2 | 11600/13484 batches | lr 4.75 | ms/batch 115.13 | loss 5.27 | ppl 195.11\n",
"| epoch 2 | 11800/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.23 | ppl 186.19\n",
"| epoch 2 | 12000/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.31 | ppl 202.19\n",
"| epoch 2 | 12200/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.22 | ppl 184.46\n",
"| epoch 2 | 12400/13484 batches | lr 4.75 | ms/batch 115.32 | loss 5.23 | ppl 187.26\n",
"| epoch 2 | 12600/13484 batches | lr 4.75 | ms/batch 115.31 | loss 5.25 | ppl 189.65\n",
"| epoch 2 | 12800/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.28 | ppl 196.25\n",
"| epoch 2 | 13000/13484 batches | lr 4.75 | ms/batch 115.35 | loss 5.28 | ppl 196.31\n",
"| epoch 2 | 13200/13484 batches | lr 4.75 | ms/batch 115.32 | loss 5.28 | ppl 195.61\n",
"| epoch 2 | 13400/13484 batches | lr 4.75 | ms/batch 115.27 | loss 5.28 | ppl 195.80\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 2 | time: 1625.71s | valid loss 5.24 | valid ppl 188.48\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 3 | 200/13484 batches | lr 4.51 | ms/batch 115.84 | loss 5.32 | ppl 205.41\n",
"| epoch 3 | 400/13484 batches | lr 4.51 | ms/batch 115.17 | loss 5.28 | ppl 195.56\n",
"| epoch 3 | 600/13484 batches | lr 4.51 | ms/batch 115.12 | loss 5.22 | ppl 184.23\n",
"| epoch 3 | 800/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.23 | ppl 187.41\n",
"| epoch 3 | 1000/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.23 | ppl 186.77\n",
"| epoch 3 | 1200/13484 batches | lr 4.51 | ms/batch 115.14 | loss 5.22 | ppl 184.68\n",
"| epoch 3 | 1400/13484 batches | lr 4.51 | ms/batch 115.18 | loss 5.20 | ppl 181.17\n",
"| epoch 3 | 1600/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.25 | ppl 191.20\n",
"| epoch 3 | 1800/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.23 | ppl 186.87\n",
"| epoch 3 | 2000/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.19 | ppl 180.16\n",
"| epoch 3 | 2200/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.21 | ppl 183.82\n",
"| epoch 3 | 2400/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.21 | ppl 182.76\n",
"| epoch 3 | 2600/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.19 | ppl 180.25\n",
"| epoch 3 | 2800/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.22 | ppl 185.75\n",
"| epoch 3 | 3000/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.21 | ppl 183.06\n",
"| epoch 3 | 3200/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.17 | ppl 176.28\n",
"| epoch 3 | 3400/13484 batches | lr 4.51 | ms/batch 115.16 | loss 5.24 | ppl 187.88\n",
"| epoch 3 | 3600/13484 batches | lr 4.51 | ms/batch 115.16 | loss 5.21 | ppl 182.87\n",
"| epoch 3 | 3800/13484 batches | lr 4.51 | ms/batch 115.18 | loss 5.21 | ppl 182.52\n",
"| epoch 3 | 4000/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.15 | ppl 172.43\n",
"| epoch 3 | 4200/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.18 | ppl 177.72\n",
"| epoch 3 | 4400/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.19 | ppl 179.22\n",
"| epoch 3 | 4600/13484 batches | lr 4.51 | ms/batch 115.26 | loss 5.24 | ppl 187.99\n",
"| epoch 3 | 4800/13484 batches | lr 4.51 | ms/batch 115.17 | loss 5.24 | ppl 188.20\n",
"| epoch 3 | 5000/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.22 | ppl 184.24\n",
"| epoch 3 | 5200/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.21 | ppl 184.01\n",
"| epoch 3 | 5400/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.19 | ppl 178.94\n",
"| epoch 3 | 5600/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.19 | ppl 180.15\n",
"| epoch 3 | 5800/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.20 | ppl 181.24\n",
"| epoch 3 | 6000/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.22 | ppl 184.08\n",
"| epoch 3 | 6200/13484 batches | lr 4.51 | ms/batch 115.29 | loss 5.24 | ppl 187.77\n",
"| epoch 3 | 6400/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.21 | ppl 182.36\n",
"| epoch 3 | 6600/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.25 | ppl 190.52\n",
"| epoch 3 | 6800/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.20 | ppl 180.56\n",
"| epoch 3 | 7000/13484 batches | lr 4.51 | ms/batch 115.14 | loss 5.23 | ppl 186.73\n",
"| epoch 3 | 7200/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.19 | ppl 179.90\n",
"| epoch 3 | 7400/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.21 | ppl 182.43\n",
"| epoch 3 | 7600/13484 batches | lr 4.51 | ms/batch 115.09 | loss 5.20 | ppl 181.48\n",
"| epoch 3 | 7800/13484 batches | lr 4.51 | ms/batch 115.26 | loss 5.22 | ppl 185.25\n",
"| epoch 3 | 8000/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.18 | ppl 178.05\n",
"| epoch 3 | 8200/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.18 | ppl 178.41\n",
"| epoch 3 | 8400/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.20 | ppl 181.07\n",
"| epoch 3 | 8600/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.20 | ppl 182.10\n",
"| epoch 3 | 8800/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.18 | ppl 177.86\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 3 | 9000/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.21 | ppl 182.88\n",
"| epoch 3 | 9200/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.21 | ppl 183.47\n",
"| epoch 3 | 9400/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.22 | ppl 184.25\n",
"| epoch 3 | 9600/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.18 | ppl 177.24\n",
"| epoch 3 | 9800/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.16 | ppl 174.24\n",
"| epoch 3 | 10000/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.21 | ppl 182.29\n",
"| epoch 3 | 10200/13484 batches | lr 4.51 | ms/batch 115.13 | loss 5.17 | ppl 175.34\n",
"| epoch 3 | 10400/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.14 | ppl 170.79\n",
"| epoch 3 | 10600/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.17 | ppl 176.55\n",
"| epoch 3 | 10800/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.22 | ppl 185.77\n",
"| epoch 3 | 11000/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.19 | ppl 179.38\n",
"| epoch 3 | 11200/13484 batches | lr 4.51 | ms/batch 115.30 | loss 5.19 | ppl 179.59\n",
"| epoch 3 | 11400/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.14 | ppl 171.38\n",
"| epoch 3 | 11600/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.18 | ppl 178.51\n",
"| epoch 3 | 11800/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.16 | ppl 174.52\n",
"| epoch 3 | 12000/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.20 | ppl 181.28\n",
"| epoch 3 | 12200/13484 batches | lr 4.51 | ms/batch 115.13 | loss 5.14 | ppl 170.89\n",
"| epoch 3 | 12400/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.14 | ppl 169.88\n",
"| epoch 3 | 12600/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.15 | ppl 172.67\n",
"| epoch 3 | 12800/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.18 | ppl 176.89\n",
"| epoch 3 | 13000/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.19 | ppl 179.90\n",
"| epoch 3 | 13200/13484 batches | lr 4.51 | ms/batch 115.29 | loss 5.17 | ppl 175.90\n",
"| epoch 3 | 13400/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.20 | ppl 182.07\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 3 | time: 1625.86s | valid loss 5.19 | valid ppl 178.94\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 4 | 200/13484 batches | lr 4.29 | ms/batch 115.82 | loss 5.22 | ppl 184.03\n",
"| epoch 4 | 400/13484 batches | lr 4.29 | ms/batch 115.24 | loss 5.16 | ppl 174.84\n",
"| epoch 4 | 600/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.14 | ppl 170.18\n",
"| epoch 4 | 800/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.15 | ppl 171.64\n",
"| epoch 4 | 1000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.14 | ppl 171.26\n",
"| epoch 4 | 1200/13484 batches | lr 4.29 | ms/batch 115.14 | loss 5.14 | ppl 171.18\n",
"| epoch 4 | 1400/13484 batches | lr 4.29 | ms/batch 115.13 | loss 5.12 | ppl 166.55\n",
"| epoch 4 | 1600/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.17 | ppl 176.35\n",
"| epoch 4 | 1800/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.15 | ppl 172.34\n",
"| epoch 4 | 2000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.13 | ppl 169.46\n",
"| epoch 4 | 2200/13484 batches | lr 4.29 | ms/batch 115.23 | loss 5.16 | ppl 173.74\n",
"| epoch 4 | 2400/13484 batches | lr 4.29 | ms/batch 115.14 | loss 5.14 | ppl 170.76\n",
"| epoch 4 | 2600/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.11 | ppl 165.36\n",
"| epoch 4 | 2800/13484 batches | lr 4.29 | ms/batch 115.14 | loss 5.15 | ppl 173.15\n",
"| epoch 4 | 3000/13484 batches | lr 4.29 | ms/batch 115.15 | loss 5.14 | ppl 171.39\n",
"| epoch 4 | 3200/13484 batches | lr 4.29 | ms/batch 115.27 | loss 5.10 | ppl 164.27\n",
"| epoch 4 | 3400/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.16 | ppl 174.64\n",
"| epoch 4 | 3600/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.13 | ppl 168.98\n",
"| epoch 4 | 3800/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.12 | ppl 167.42\n",
"| epoch 4 | 4000/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.08 | ppl 161.02\n",
"| epoch 4 | 4200/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.11 | ppl 165.33\n",
"| epoch 4 | 4400/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.11 | ppl 165.61\n",
"| epoch 4 | 4600/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.16 | ppl 173.90\n",
"| epoch 4 | 4800/13484 batches | lr 4.29 | ms/batch 115.27 | loss 5.15 | ppl 172.81\n",
"| epoch 4 | 5000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.14 | ppl 169.98\n",
"| epoch 4 | 5200/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.13 | ppl 168.94\n",
"| epoch 4 | 5400/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.10 | ppl 164.28\n",
"| epoch 4 | 5600/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.12 | ppl 167.23\n",
"| epoch 4 | 5800/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.12 | ppl 167.63\n",
"| epoch 4 | 6000/13484 batches | lr 4.29 | ms/batch 115.34 | loss 5.14 | ppl 170.26\n",
"| epoch 4 | 6200/13484 batches | lr 4.29 | ms/batch 115.31 | loss 5.18 | ppl 177.13\n",
"| epoch 4 | 6400/13484 batches | lr 4.29 | ms/batch 115.27 | loss 5.13 | ppl 169.45\n",
"| epoch 4 | 6600/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.16 | ppl 174.83\n",
"| epoch 4 | 6800/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.11 | ppl 165.20\n",
"| epoch 4 | 7000/13484 batches | lr 4.29 | ms/batch 115.19 | loss 5.16 | ppl 174.72\n",
"| epoch 4 | 7200/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.12 | ppl 167.83\n",
"| epoch 4 | 7400/13484 batches | lr 4.29 | ms/batch 115.29 | loss 5.12 | ppl 167.13\n",
"| epoch 4 | 7600/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.13 | ppl 168.29\n",
"| epoch 4 | 7800/13484 batches | lr 4.29 | ms/batch 115.30 | loss 5.12 | ppl 167.88\n",
"| epoch 4 | 8000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.11 | ppl 165.65\n",
"| epoch 4 | 8200/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.10 | ppl 164.16\n",
"| epoch 4 | 8400/13484 batches | lr 4.29 | ms/batch 115.29 | loss 5.12 | ppl 166.71\n",
"| epoch 4 | 8600/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.14 | ppl 169.91\n",
"| epoch 4 | 8800/13484 batches | lr 4.29 | ms/batch 115.30 | loss 5.11 | ppl 166.00\n",
"| epoch 4 | 9000/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.13 | ppl 169.67\n",
"| epoch 4 | 9200/13484 batches | lr 4.29 | ms/batch 115.31 | loss 5.13 | ppl 169.46\n",
"| epoch 4 | 9400/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.15 | ppl 171.85\n",
"| epoch 4 | 9600/13484 batches | lr 4.29 | ms/batch 115.29 | loss 5.11 | ppl 165.01\n",
"| epoch 4 | 9800/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.09 | ppl 162.51\n",
"| epoch 4 | 10000/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.12 | ppl 167.81\n",
"| epoch 4 | 10200/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.10 | ppl 163.43\n",
"| epoch 4 | 10400/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.07 | ppl 158.79\n",
"| epoch 4 | 10600/13484 batches | lr 4.29 | ms/batch 115.30 | loss 5.10 | ppl 163.54\n",
"| epoch 4 | 10800/13484 batches | lr 4.29 | ms/batch 115.39 | loss 5.12 | ppl 167.03\n",
"| epoch 4 | 11000/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.11 | ppl 166.04\n",
"| epoch 4 | 11200/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.11 | ppl 165.67\n",
"| epoch 4 | 11400/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.06 | ppl 157.42\n",
"| epoch 4 | 11600/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.10 | ppl 164.17\n",
"| epoch 4 | 11800/13484 batches | lr 4.29 | ms/batch 115.36 | loss 5.07 | ppl 159.41\n",
"| epoch 4 | 12000/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.13 | ppl 168.33\n",
"| epoch 4 | 12200/13484 batches | lr 4.29 | ms/batch 115.24 | loss 5.05 | ppl 155.52\n",
"| epoch 4 | 12400/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.07 | ppl 159.62\n",
"| epoch 4 | 12600/13484 batches | lr 4.29 | ms/batch 115.24 | loss 5.09 | ppl 161.65\n",
"| epoch 4 | 12800/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.10 | ppl 164.49\n",
"| epoch 4 | 13000/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.10 | ppl 163.47\n",
"| epoch 4 | 13200/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.09 | ppl 162.89\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 4 | 13400/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.12 | ppl 166.66\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 4 | time: 1626.36s | valid loss 5.13 | valid ppl 168.54\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 5 | 200/13484 batches | lr 4.07 | ms/batch 115.90 | loss 5.14 | ppl 170.65\n",
"| epoch 5 | 400/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.09 | ppl 163.18\n",
"| epoch 5 | 600/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.07 | ppl 159.22\n",
"| epoch 5 | 800/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.08 | ppl 160.60\n",
"| epoch 5 | 1000/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.08 | ppl 160.49\n",
"| epoch 5 | 1200/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.07 | ppl 158.86\n",
"| epoch 5 | 1400/13484 batches | lr 4.07 | ms/batch 115.14 | loss 5.06 | ppl 156.88\n",
"| epoch 5 | 1600/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.10 | ppl 164.68\n",
"| epoch 5 | 1800/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.09 | ppl 161.68\n",
"| epoch 5 | 2000/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.05 | ppl 156.19\n",
"| epoch 5 | 2200/13484 batches | lr 4.07 | ms/batch 115.16 | loss 5.06 | ppl 157.65\n",
"| epoch 5 | 2400/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.05 | ppl 156.29\n",
"| epoch 5 | 2600/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.04 | ppl 155.08\n",
"| epoch 5 | 2800/13484 batches | lr 4.07 | ms/batch 115.12 | loss 5.08 | ppl 160.79\n",
"| epoch 5 | 3000/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.06 | ppl 157.93\n",
"| epoch 5 | 3200/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.03 | ppl 153.59\n",
"| epoch 5 | 3400/13484 batches | lr 4.07 | ms/batch 115.24 | loss 5.10 | ppl 164.69\n",
"| epoch 5 | 3600/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.07 | ppl 159.67\n",
"| epoch 5 | 3800/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.05 | ppl 156.33\n",
"| epoch 5 | 4000/13484 batches | lr 4.07 | ms/batch 115.30 | loss 5.00 | ppl 148.52\n",
"| epoch 5 | 4200/13484 batches | lr 4.07 | ms/batch 115.16 | loss 5.03 | ppl 153.04\n",
"| epoch 5 | 4400/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.04 | ppl 155.12\n",
"| epoch 5 | 4600/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.09 | ppl 162.86\n",
"| epoch 5 | 4800/13484 batches | lr 4.07 | ms/batch 115.21 | loss 5.07 | ppl 159.17\n",
"| epoch 5 | 5000/13484 batches | lr 4.07 | ms/batch 115.27 | loss 5.06 | ppl 157.50\n",
"| epoch 5 | 5200/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.06 | ppl 157.70\n",
"| epoch 5 | 5400/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.04 | ppl 154.31\n",
"| epoch 5 | 5600/13484 batches | lr 4.07 | ms/batch 115.18 | loss 5.05 | ppl 156.47\n",
"| epoch 5 | 5800/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.06 | ppl 157.43\n",
"| epoch 5 | 6000/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.07 | ppl 159.33\n",
"| epoch 5 | 6200/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.09 | ppl 163.19\n",
"| epoch 5 | 6400/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.07 | ppl 159.77\n",
"| epoch 5 | 6600/13484 batches | lr 4.07 | ms/batch 115.28 | loss 5.09 | ppl 163.17\n",
"| epoch 5 | 6800/13484 batches | lr 4.07 | ms/batch 115.11 | loss 5.03 | ppl 153.03\n",
"| epoch 5 | 7000/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.09 | ppl 161.90\n",
"| epoch 5 | 7200/13484 batches | lr 4.07 | ms/batch 115.21 | loss 5.06 | ppl 156.90\n",
"| epoch 5 | 7400/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.07 | ppl 159.02\n",
"| epoch 5 | 7600/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.05 | ppl 156.02\n",
"| epoch 5 | 7800/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.06 | ppl 157.20\n",
"| epoch 5 | 8000/13484 batches | lr 4.07 | ms/batch 115.20 | loss 5.04 | ppl 154.56\n",
"| epoch 5 | 8200/13484 batches | lr 4.07 | ms/batch 115.20 | loss 5.03 | ppl 152.46\n",
"| epoch 5 | 8400/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.06 | ppl 157.62\n",
"| epoch 5 | 8600/13484 batches | lr 4.07 | ms/batch 115.28 | loss 5.07 | ppl 158.74\n",
"| epoch 5 | 8800/13484 batches | lr 4.07 | ms/batch 115.30 | loss 5.04 | ppl 154.53\n",
"| epoch 5 | 9000/13484 batches | lr 4.07 | ms/batch 115.31 | loss 5.06 | ppl 157.02\n",
"| epoch 5 | 9200/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.07 | ppl 159.14\n",
"| epoch 5 | 9400/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.07 | ppl 159.15\n",
"| epoch 5 | 9600/13484 batches | lr 4.07 | ms/batch 115.24 | loss 5.04 | ppl 153.89\n",
"| epoch 5 | 9800/13484 batches | lr 4.07 | ms/batch 115.27 | loss 5.02 | ppl 151.96\n",
"| epoch 5 | 10000/13484 batches | lr 4.07 | ms/batch 115.24 | loss 5.05 | ppl 156.58\n",
"| epoch 5 | 10200/13484 batches | lr 4.07 | ms/batch 115.30 | loss 5.02 | ppl 152.10\n",
"| epoch 5 | 10400/13484 batches | lr 4.07 | ms/batch 115.33 | loss 4.98 | ppl 146.11\n",
"| epoch 5 | 10600/13484 batches | lr 4.07 | ms/batch 115.26 | loss 5.03 | ppl 153.27\n",
"| epoch 5 | 10800/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.06 | ppl 157.44\n",
"| epoch 5 | 11000/13484 batches | lr 4.07 | ms/batch 115.33 | loss 5.05 | ppl 156.34\n",
"| epoch 5 | 11200/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.04 | ppl 154.36\n",
"| epoch 5 | 11400/13484 batches | lr 4.07 | ms/batch 115.27 | loss 5.00 | ppl 148.51\n",
"| epoch 5 | 11600/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.03 | ppl 153.09\n",
"| epoch 5 | 11800/13484 batches | lr 4.07 | ms/batch 115.26 | loss 5.00 | ppl 148.85\n",
"| epoch 5 | 12000/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.06 | ppl 156.93\n",
"| epoch 5 | 12200/13484 batches | lr 4.07 | ms/batch 115.22 | loss 4.98 | ppl 145.89\n",
"| epoch 5 | 12400/13484 batches | lr 4.07 | ms/batch 115.20 | loss 5.00 | ppl 148.86\n",
"| epoch 5 | 12600/13484 batches | lr 4.07 | ms/batch 115.33 | loss 5.02 | ppl 151.32\n",
"| epoch 5 | 12800/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.04 | ppl 154.42\n",
"| epoch 5 | 13000/13484 batches | lr 4.07 | ms/batch 115.28 | loss 5.03 | ppl 152.95\n",
"| epoch 5 | 13200/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.03 | ppl 153.49\n",
"| epoch 5 | 13400/13484 batches | lr 4.07 | ms/batch 115.35 | loss 5.05 | ppl 155.92\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 5 | time: 1625.93s | valid loss 5.10 | valid ppl 164.06\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 6 | 200/13484 batches | lr 3.87 | ms/batch 115.82 | loss 5.07 | ppl 159.79\n",
"| epoch 6 | 400/13484 batches | lr 3.87 | ms/batch 115.22 | loss 5.04 | ppl 153.75\n",
"| epoch 6 | 600/13484 batches | lr 3.87 | ms/batch 115.14 | loss 5.01 | ppl 150.20\n",
"| epoch 6 | 800/13484 batches | lr 3.87 | ms/batch 115.25 | loss 5.01 | ppl 150.24\n",
"| epoch 6 | 1000/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.01 | ppl 149.42\n",
"| epoch 6 | 1200/13484 batches | lr 3.87 | ms/batch 115.09 | loss 5.01 | ppl 150.28\n",
"| epoch 6 | 1400/13484 batches | lr 3.87 | ms/batch 115.18 | loss 5.00 | ppl 148.53\n",
"| epoch 6 | 1600/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.05 | ppl 156.45\n",
"| epoch 6 | 1800/13484 batches | lr 3.87 | ms/batch 115.17 | loss 5.02 | ppl 151.97\n",
"| epoch 6 | 2000/13484 batches | lr 3.87 | ms/batch 115.14 | loss 5.00 | ppl 147.68\n",
"| epoch 6 | 2200/13484 batches | lr 3.87 | ms/batch 115.22 | loss 5.00 | ppl 148.99\n",
"| epoch 6 | 2400/13484 batches | lr 3.87 | ms/batch 115.19 | loss 5.00 | ppl 147.82\n",
"| epoch 6 | 2600/13484 batches | lr 3.87 | ms/batch 115.19 | loss 4.98 | ppl 145.20\n",
"| epoch 6 | 2800/13484 batches | lr 3.87 | ms/batch 115.20 | loss 5.02 | ppl 152.00\n",
"| epoch 6 | 3000/13484 batches | lr 3.87 | ms/batch 115.20 | loss 5.01 | ppl 149.24\n",
"| epoch 6 | 3200/13484 batches | lr 3.87 | ms/batch 115.23 | loss 4.98 | ppl 145.09\n",
"| epoch 6 | 3400/13484 batches | lr 3.87 | ms/batch 115.35 | loss 5.03 | ppl 153.68\n",
"| epoch 6 | 3600/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.01 | ppl 149.34\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 6 | 3800/13484 batches | lr 3.87 | ms/batch 115.20 | loss 5.00 | ppl 148.07\n",
"| epoch 6 | 4000/13484 batches | lr 3.87 | ms/batch 115.32 | loss 4.94 | ppl 140.04\n",
"| epoch 6 | 4200/13484 batches | lr 3.87 | ms/batch 115.21 | loss 4.97 | ppl 144.64\n",
"| epoch 6 | 4400/13484 batches | lr 3.87 | ms/batch 115.20 | loss 4.99 | ppl 146.48\n",
"| epoch 6 | 4600/13484 batches | lr 3.87 | ms/batch 115.18 | loss 5.03 | ppl 153.49\n",
"| epoch 6 | 4800/13484 batches | lr 3.87 | ms/batch 115.30 | loss 5.01 | ppl 150.20\n",
"| epoch 6 | 5000/13484 batches | lr 3.87 | ms/batch 115.24 | loss 5.00 | ppl 148.23\n",
"| epoch 6 | 5200/13484 batches | lr 3.87 | ms/batch 115.22 | loss 5.00 | ppl 148.51\n",
"| epoch 6 | 5400/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.98 | ppl 145.45\n",
"| epoch 6 | 5600/13484 batches | lr 3.87 | ms/batch 115.25 | loss 4.99 | ppl 146.84\n",
"| epoch 6 | 5800/13484 batches | lr 3.87 | ms/batch 115.29 | loss 4.99 | ppl 147.24\n",
"| epoch 6 | 6000/13484 batches | lr 3.87 | ms/batch 115.19 | loss 5.01 | ppl 150.09\n",
"| epoch 6 | 6200/13484 batches | lr 3.87 | ms/batch 115.21 | loss 5.03 | ppl 152.86\n",
"| epoch 6 | 6400/13484 batches | lr 3.87 | ms/batch 115.17 | loss 5.02 | ppl 150.83\n",
"| epoch 6 | 6600/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.05 | ppl 155.25\n",
"| epoch 6 | 6800/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.00 | ppl 148.10\n",
"| epoch 6 | 7000/13484 batches | lr 3.87 | ms/batch 115.35 | loss 5.03 | ppl 152.52\n",
"| epoch 6 | 7200/13484 batches | lr 3.87 | ms/batch 115.25 | loss 5.00 | ppl 148.62\n",
"| epoch 6 | 7400/13484 batches | lr 3.87 | ms/batch 115.30 | loss 5.00 | ppl 148.56\n",
"| epoch 6 | 7600/13484 batches | lr 3.87 | ms/batch 115.25 | loss 4.99 | ppl 147.28\n",
"| epoch 6 | 7800/13484 batches | lr 3.87 | ms/batch 115.24 | loss 5.00 | ppl 147.93\n",
"| epoch 6 | 8000/13484 batches | lr 3.87 | ms/batch 115.24 | loss 4.98 | ppl 145.76\n",
"| epoch 6 | 8200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.97 | ppl 143.39\n",
"| epoch 6 | 8400/13484 batches | lr 3.87 | ms/batch 115.24 | loss 4.99 | ppl 147.14\n",
"| epoch 6 | 8600/13484 batches | lr 3.87 | ms/batch 115.27 | loss 5.00 | ppl 148.00\n",
"| epoch 6 | 8800/13484 batches | lr 3.87 | ms/batch 115.35 | loss 4.98 | ppl 145.27\n",
"| epoch 6 | 9000/13484 batches | lr 3.87 | ms/batch 115.27 | loss 5.01 | ppl 150.06\n",
"| epoch 6 | 9200/13484 batches | lr 3.87 | ms/batch 115.21 | loss 5.01 | ppl 150.09\n",
"| epoch 6 | 9400/13484 batches | lr 3.87 | ms/batch 115.28 | loss 5.01 | ppl 150.08\n",
"| epoch 6 | 9600/13484 batches | lr 3.87 | ms/batch 115.16 | loss 4.99 | ppl 147.55\n",
"| epoch 6 | 9800/13484 batches | lr 3.87 | ms/batch 115.27 | loss 4.97 | ppl 143.67\n",
"| epoch 6 | 10000/13484 batches | lr 3.87 | ms/batch 115.20 | loss 4.99 | ppl 147.66\n",
"| epoch 6 | 10200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.95 | ppl 141.61\n",
"| epoch 6 | 10400/13484 batches | lr 3.87 | ms/batch 115.20 | loss 4.93 | ppl 138.76\n",
"| epoch 6 | 10600/13484 batches | lr 3.87 | ms/batch 115.28 | loss 4.97 | ppl 144.59\n",
"| epoch 6 | 10800/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.01 | ppl 149.16\n",
"| epoch 6 | 11000/13484 batches | lr 3.87 | ms/batch 115.29 | loss 5.00 | ppl 148.35\n",
"| epoch 6 | 11200/13484 batches | lr 3.87 | ms/batch 115.29 | loss 5.01 | ppl 149.31\n",
"| epoch 6 | 11400/13484 batches | lr 3.87 | ms/batch 115.29 | loss 4.95 | ppl 141.26\n",
"| epoch 6 | 11600/13484 batches | lr 3.87 | ms/batch 115.34 | loss 4.98 | ppl 145.07\n",
"| epoch 6 | 11800/13484 batches | lr 3.87 | ms/batch 115.28 | loss 4.94 | ppl 140.00\n",
"| epoch 6 | 12000/13484 batches | lr 3.87 | ms/batch 115.19 | loss 5.00 | ppl 147.85\n",
"| epoch 6 | 12200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.93 | ppl 137.74\n",
"| epoch 6 | 12400/13484 batches | lr 3.87 | ms/batch 115.26 | loss 4.95 | ppl 140.89\n",
"| epoch 6 | 12600/13484 batches | lr 3.87 | ms/batch 115.38 | loss 4.97 | ppl 143.33\n",
"| epoch 6 | 12800/13484 batches | lr 3.87 | ms/batch 115.29 | loss 4.98 | ppl 145.29\n",
"| epoch 6 | 13000/13484 batches | lr 3.87 | ms/batch 115.37 | loss 4.97 | ppl 144.45\n",
"| epoch 6 | 13200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.98 | ppl 146.13\n",
"| epoch 6 | 13400/13484 batches | lr 3.87 | ms/batch 115.33 | loss 5.00 | ppl 148.36\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 6 | time: 1626.41s | valid loss 5.09 | valid ppl 162.11\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 7 | 200/13484 batches | lr 3.68 | ms/batch 115.82 | loss 5.02 | ppl 151.17\n",
"| epoch 7 | 400/13484 batches | lr 3.68 | ms/batch 115.20 | loss 4.98 | ppl 144.87\n",
"| epoch 7 | 600/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.96 | ppl 142.39\n",
"| epoch 7 | 800/13484 batches | lr 3.68 | ms/batch 115.16 | loss 4.96 | ppl 142.45\n",
"| epoch 7 | 1000/13484 batches | lr 3.68 | ms/batch 115.08 | loss 4.96 | ppl 142.21\n",
"| epoch 7 | 1200/13484 batches | lr 3.68 | ms/batch 115.19 | loss 4.96 | ppl 142.21\n",
"| epoch 7 | 1400/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.94 | ppl 139.97\n",
"| epoch 7 | 1600/13484 batches | lr 3.68 | ms/batch 115.13 | loss 4.99 | ppl 146.87\n",
"| epoch 7 | 1800/13484 batches | lr 3.68 | ms/batch 115.11 | loss 4.97 | ppl 144.27\n",
"| epoch 7 | 2000/13484 batches | lr 3.68 | ms/batch 115.14 | loss 4.94 | ppl 139.63\n",
"| epoch 7 | 2200/13484 batches | lr 3.68 | ms/batch 115.13 | loss 4.94 | ppl 140.28\n",
"| epoch 7 | 2400/13484 batches | lr 3.68 | ms/batch 115.14 | loss 4.94 | ppl 140.42\n",
"| epoch 7 | 2600/13484 batches | lr 3.68 | ms/batch 115.20 | loss 4.93 | ppl 138.37\n",
"| epoch 7 | 2800/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.97 | ppl 144.51\n",
"| epoch 7 | 3000/13484 batches | lr 3.68 | ms/batch 115.22 | loss 4.95 | ppl 141.43\n",
"| epoch 7 | 3200/13484 batches | lr 3.68 | ms/batch 115.17 | loss 4.92 | ppl 137.29\n",
"| epoch 7 | 3400/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.98 | ppl 145.62\n",
"| epoch 7 | 3600/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.95 | ppl 141.60\n",
"| epoch 7 | 3800/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.94 | ppl 139.88\n",
"| epoch 7 | 4000/13484 batches | lr 3.68 | ms/batch 115.19 | loss 4.89 | ppl 133.49\n",
"| epoch 7 | 4200/13484 batches | lr 3.68 | ms/batch 115.17 | loss 4.93 | ppl 138.21\n",
"| epoch 7 | 4400/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.94 | ppl 139.14\n",
"| epoch 7 | 4600/13484 batches | lr 3.68 | ms/batch 115.28 | loss 4.98 | ppl 145.67\n",
"| epoch 7 | 4800/13484 batches | lr 3.68 | ms/batch 115.25 | loss 4.96 | ppl 143.05\n",
"| epoch 7 | 5000/13484 batches | lr 3.68 | ms/batch 115.20 | loss 4.95 | ppl 141.27\n",
"| epoch 7 | 5200/13484 batches | lr 3.68 | ms/batch 115.18 | loss 4.95 | ppl 140.78\n",
"| epoch 7 | 5400/13484 batches | lr 3.68 | ms/batch 115.35 | loss 4.93 | ppl 137.98\n",
"| epoch 7 | 5600/13484 batches | lr 3.68 | ms/batch 115.29 | loss 4.94 | ppl 139.66\n",
"| epoch 7 | 5800/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.94 | ppl 139.99\n",
"| epoch 7 | 6000/13484 batches | lr 3.68 | ms/batch 115.26 | loss 4.96 | ppl 142.34\n",
"| epoch 7 | 6200/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.99 | ppl 146.32\n",
"| epoch 7 | 6400/13484 batches | lr 3.68 | ms/batch 115.18 | loss 4.96 | ppl 142.33\n",
"| epoch 7 | 6600/13484 batches | lr 3.68 | ms/batch 115.22 | loss 4.99 | ppl 146.69\n",
"| epoch 7 | 6800/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.93 | ppl 137.90\n",
"| epoch 7 | 7000/13484 batches | lr 3.68 | ms/batch 115.18 | loss 4.98 | ppl 145.72\n",
"| epoch 7 | 7200/13484 batches | lr 3.68 | ms/batch 115.25 | loss 4.94 | ppl 140.06\n",
"| epoch 7 | 7400/13484 batches | lr 3.68 | ms/batch 115.14 | loss 4.94 | ppl 140.43\n",
"| epoch 7 | 7600/13484 batches | lr 3.68 | ms/batch 115.28 | loss 4.95 | ppl 140.71\n",
"| epoch 7 | 7800/13484 batches | lr 3.68 | ms/batch 115.26 | loss 4.94 | ppl 140.23\n",
"| epoch 7 | 8000/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.93 | ppl 138.76\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 7 | 8200/13484 batches | lr 3.68 | ms/batch 115.26 | loss 4.92 | ppl 136.77\n",
"| epoch 7 | 8400/13484 batches | lr 3.68 | ms/batch 115.33 | loss 4.94 | ppl 139.78\n",
"| epoch 7 | 8600/13484 batches | lr 3.68 | ms/batch 115.31 | loss 4.95 | ppl 141.01\n",
"| epoch 7 | 8800/13484 batches | lr 3.68 | ms/batch 115.18 | loss 4.93 | ppl 138.80\n",
"| epoch 7 | 9000/13484 batches | lr 3.68 | ms/batch 115.19 | loss 4.95 | ppl 141.73\n",
"| epoch 7 | 9200/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.97 | ppl 144.05\n",
"| epoch 7 | 9400/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.97 | ppl 144.66\n",
"| epoch 7 | 9600/13484 batches | lr 3.68 | ms/batch 115.26 | loss 4.93 | ppl 138.12\n",
"| epoch 7 | 9800/13484 batches | lr 3.68 | ms/batch 115.27 | loss 4.91 | ppl 135.39\n",
"| epoch 7 | 10000/13484 batches | lr 3.68 | ms/batch 115.19 | loss 4.94 | ppl 140.12\n",
"| epoch 7 | 10200/13484 batches | lr 3.68 | ms/batch 115.27 | loss 4.90 | ppl 134.73\n",
"| epoch 7 | 10400/13484 batches | lr 3.68 | ms/batch 115.29 | loss 4.87 | ppl 130.45\n",
"| epoch 7 | 10600/13484 batches | lr 3.68 | ms/batch 115.35 | loss 4.92 | ppl 137.36\n",
"| epoch 7 | 10800/13484 batches | lr 3.68 | ms/batch 115.29 | loss 4.94 | ppl 140.35\n",
"| epoch 7 | 11000/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.95 | ppl 141.22\n",
"| epoch 7 | 11200/13484 batches | lr 3.68 | ms/batch 115.28 | loss 4.94 | ppl 139.33\n",
"| epoch 7 | 11400/13484 batches | lr 3.68 | ms/batch 115.28 | loss 4.89 | ppl 133.07\n",
"| epoch 7 | 11600/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.93 | ppl 137.82\n",
"| epoch 7 | 11800/13484 batches | lr 3.68 | ms/batch 115.33 | loss 4.89 | ppl 132.51\n",
"| epoch 7 | 12000/13484 batches | lr 3.68 | ms/batch 115.32 | loss 4.94 | ppl 139.89\n",
"| epoch 7 | 12200/13484 batches | lr 3.68 | ms/batch 115.25 | loss 4.88 | ppl 131.43\n",
"| epoch 7 | 12400/13484 batches | lr 3.68 | ms/batch 115.32 | loss 4.89 | ppl 133.23\n",
"| epoch 7 | 12600/13484 batches | lr 3.68 | ms/batch 115.30 | loss 4.92 | ppl 136.69\n",
"| epoch 7 | 12800/13484 batches | lr 3.68 | ms/batch 115.27 | loss 4.94 | ppl 139.23\n",
"| epoch 7 | 13000/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.92 | ppl 136.46\n",
"| epoch 7 | 13200/13484 batches | lr 3.68 | ms/batch 115.31 | loss 4.92 | ppl 137.53\n",
"| epoch 7 | 13400/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.94 | ppl 140.23\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 7 | time: 1626.06s | valid loss 5.05 | valid ppl 155.94\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 8 | 200/13484 batches | lr 3.49 | ms/batch 115.85 | loss 4.97 | ppl 143.91\n",
"| epoch 8 | 400/13484 batches | lr 3.49 | ms/batch 115.20 | loss 4.93 | ppl 138.69\n",
"| epoch 8 | 600/13484 batches | lr 3.49 | ms/batch 115.25 | loss 4.92 | ppl 137.31\n",
"| epoch 8 | 800/13484 batches | lr 3.49 | ms/batch 115.21 | loss 4.91 | ppl 135.14\n",
"| epoch 8 | 1000/13484 batches | lr 3.49 | ms/batch 115.09 | loss 4.91 | ppl 136.03\n",
"| epoch 8 | 1200/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.91 | ppl 135.25\n",
"| epoch 8 | 1400/13484 batches | lr 3.49 | ms/batch 115.16 | loss 4.89 | ppl 132.45\n",
"| epoch 8 | 1600/13484 batches | lr 3.49 | ms/batch 115.19 | loss 4.94 | ppl 139.52\n",
"| epoch 8 | 1800/13484 batches | lr 3.49 | ms/batch 115.13 | loss 4.92 | ppl 136.90\n",
"| epoch 8 | 2000/13484 batches | lr 3.49 | ms/batch 115.23 | loss 4.89 | ppl 132.80\n",
"| epoch 8 | 2200/13484 batches | lr 3.49 | ms/batch 115.12 | loss 4.89 | ppl 132.74\n",
"| epoch 8 | 2400/13484 batches | lr 3.49 | ms/batch 115.25 | loss 4.90 | ppl 133.92\n",
"| epoch 8 | 2600/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.88 | ppl 131.38\n",
"| epoch 8 | 2800/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.92 | ppl 137.14\n",
"| epoch 8 | 3000/13484 batches | lr 3.49 | ms/batch 115.18 | loss 4.90 | ppl 134.47\n",
"| epoch 8 | 3200/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.87 | ppl 130.24\n",
"| epoch 8 | 3400/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.93 | ppl 139.00\n",
"| epoch 8 | 3600/13484 batches | lr 3.49 | ms/batch 115.20 | loss 4.91 | ppl 135.20\n",
"| epoch 8 | 3800/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.90 | ppl 133.96\n",
"| epoch 8 | 4000/13484 batches | lr 3.49 | ms/batch 115.19 | loss 4.84 | ppl 127.05\n",
"| epoch 8 | 4200/13484 batches | lr 3.49 | ms/batch 115.30 | loss 4.87 | ppl 130.76\n",
"| epoch 8 | 4400/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.88 | ppl 132.28\n",
"| epoch 8 | 4600/13484 batches | lr 3.49 | ms/batch 115.36 | loss 4.93 | ppl 138.46\n",
"| epoch 8 | 4800/13484 batches | lr 3.49 | ms/batch 115.18 | loss 4.91 | ppl 135.37\n",
"| epoch 8 | 5000/13484 batches | lr 3.49 | ms/batch 115.23 | loss 4.90 | ppl 134.12\n",
"| epoch 8 | 5200/13484 batches | lr 3.49 | ms/batch 115.21 | loss 4.90 | ppl 134.65\n",
"| epoch 8 | 5400/13484 batches | lr 3.49 | ms/batch 115.35 | loss 4.87 | ppl 130.93\n",
"| epoch 8 | 5600/13484 batches | lr 3.49 | ms/batch 115.35 | loss 4.89 | ppl 133.28\n",
"| epoch 8 | 5800/13484 batches | lr 3.49 | ms/batch 115.23 | loss 4.89 | ppl 132.54\n",
"| epoch 8 | 6000/13484 batches | lr 3.49 | ms/batch 115.22 | loss 4.91 | ppl 135.15\n",
"| epoch 8 | 6200/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.94 | ppl 139.25\n",
"| epoch 8 | 6400/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.91 | ppl 135.37\n",
"| epoch 8 | 6600/13484 batches | lr 3.49 | ms/batch 115.17 | loss 4.94 | ppl 139.28\n",
"| epoch 8 | 6800/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.88 | ppl 132.05\n",
"| epoch 8 | 7000/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.92 | ppl 137.41\n",
"| epoch 8 | 7200/13484 batches | lr 3.49 | ms/batch 115.34 | loss 4.90 | ppl 133.68\n",
"| epoch 8 | 7400/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.89 | ppl 133.58\n",
"| epoch 8 | 7600/13484 batches | lr 3.49 | ms/batch 115.26 | loss 4.90 | ppl 133.64\n",
"| epoch 8 | 7800/13484 batches | lr 3.49 | ms/batch 115.33 | loss 4.89 | ppl 133.55\n",
"| epoch 8 | 8000/13484 batches | lr 3.49 | ms/batch 115.17 | loss 4.88 | ppl 132.23\n",
"| epoch 8 | 8200/13484 batches | lr 3.49 | ms/batch 115.25 | loss 4.87 | ppl 129.93\n",
"| epoch 8 | 8400/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.89 | ppl 133.16\n",
"| epoch 8 | 8600/13484 batches | lr 3.49 | ms/batch 115.30 | loss 4.89 | ppl 133.49\n",
"| epoch 8 | 8800/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.88 | ppl 131.42\n",
"| epoch 8 | 9000/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.89 | ppl 133.59\n",
"| epoch 8 | 9200/13484 batches | lr 3.49 | ms/batch 115.28 | loss 4.91 | ppl 136.20\n",
"| epoch 8 | 9400/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.91 | ppl 135.54\n",
"| epoch 8 | 9600/13484 batches | lr 3.49 | ms/batch 115.32 | loss 4.88 | ppl 131.19\n",
"| epoch 8 | 9800/13484 batches | lr 3.49 | ms/batch 115.34 | loss 4.86 | ppl 128.72\n",
"| epoch 8 | 10000/13484 batches | lr 3.49 | ms/batch 115.32 | loss 4.89 | ppl 132.80\n",
"| epoch 8 | 10200/13484 batches | lr 3.49 | ms/batch 115.33 | loss 4.85 | ppl 128.25\n",
"| epoch 8 | 10400/13484 batches | lr 3.49 | ms/batch 115.35 | loss 4.83 | ppl 124.93\n",
"| epoch 8 | 10600/13484 batches | lr 3.49 | ms/batch 115.30 | loss 4.87 | ppl 130.59\n",
"| epoch 8 | 10800/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.90 | ppl 133.78\n",
"| epoch 8 | 11000/13484 batches | lr 3.49 | ms/batch 115.30 | loss 4.90 | ppl 133.75\n",
"| epoch 8 | 11200/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.89 | ppl 133.33\n",
"| epoch 8 | 11400/13484 batches | lr 3.49 | ms/batch 115.36 | loss 4.84 | ppl 126.25\n",
"| epoch 8 | 11600/13484 batches | lr 3.49 | ms/batch 115.38 | loss 4.88 | ppl 131.70\n",
"| epoch 8 | 11800/13484 batches | lr 3.49 | ms/batch 115.36 | loss 4.84 | ppl 127.09\n",
"| epoch 8 | 12000/13484 batches | lr 3.49 | ms/batch 115.38 | loss 4.89 | ppl 133.44\n",
"| epoch 8 | 12200/13484 batches | lr 3.49 | ms/batch 115.38 | loss 4.83 | ppl 124.78\n",
"| epoch 8 | 12400/13484 batches | lr 3.49 | ms/batch 115.38 | loss 4.84 | ppl 125.91\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 8 | 12600/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.86 | ppl 128.83\n",
"| epoch 8 | 12800/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.88 | ppl 131.60\n",
"| epoch 8 | 13000/13484 batches | lr 3.49 | ms/batch 115.33 | loss 4.87 | ppl 130.10\n",
"| epoch 8 | 13200/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.88 | ppl 131.87\n",
"| epoch 8 | 13400/13484 batches | lr 3.49 | ms/batch 115.40 | loss 4.90 | ppl 134.29\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 8 | time: 1626.66s | valid loss 5.00 | valid ppl 148.39\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 9 | 200/13484 batches | lr 3.32 | ms/batch 115.97 | loss 4.92 | ppl 136.72\n",
"| epoch 9 | 400/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.88 | ppl 131.62\n",
"| epoch 9 | 600/13484 batches | lr 3.32 | ms/batch 115.30 | loss 4.85 | ppl 128.22\n",
"| epoch 9 | 800/13484 batches | lr 3.32 | ms/batch 115.29 | loss 4.86 | ppl 128.84\n",
"| epoch 9 | 1000/13484 batches | lr 3.32 | ms/batch 115.29 | loss 4.86 | ppl 129.65\n",
"| epoch 9 | 1200/13484 batches | lr 3.32 | ms/batch 115.21 | loss 4.86 | ppl 128.93\n",
"| epoch 9 | 1400/13484 batches | lr 3.32 | ms/batch 115.28 | loss 4.85 | ppl 127.80\n",
"| epoch 9 | 1600/13484 batches | lr 3.32 | ms/batch 115.36 | loss 4.89 | ppl 132.74\n",
"| epoch 9 | 1800/13484 batches | lr 3.32 | ms/batch 115.27 | loss 4.88 | ppl 131.14\n",
"| epoch 9 | 2000/13484 batches | lr 3.32 | ms/batch 115.32 | loss 4.84 | ppl 126.60\n",
"| epoch 9 | 2200/13484 batches | lr 3.32 | ms/batch 115.33 | loss 4.84 | ppl 126.74\n",
"| epoch 9 | 2400/13484 batches | lr 3.32 | ms/batch 115.32 | loss 4.84 | ppl 127.02\n",
"| epoch 9 | 2600/13484 batches | lr 3.32 | ms/batch 115.31 | loss 4.84 | ppl 126.21\n",
"| epoch 9 | 2800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.87 | ppl 130.53\n",
"| epoch 9 | 3000/13484 batches | lr 3.32 | ms/batch 115.31 | loss 4.85 | ppl 127.68\n",
"| epoch 9 | 3200/13484 batches | lr 3.32 | ms/batch 115.30 | loss 4.83 | ppl 125.33\n",
"| epoch 9 | 3400/13484 batches | lr 3.32 | ms/batch 115.26 | loss 4.89 | ppl 133.40\n",
"| epoch 9 | 3600/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.86 | ppl 129.20\n",
"| epoch 9 | 3800/13484 batches | lr 3.32 | ms/batch 115.39 | loss 4.85 | ppl 127.67\n",
"| epoch 9 | 4000/13484 batches | lr 3.32 | ms/batch 115.39 | loss 4.80 | ppl 121.75\n",
"| epoch 9 | 4200/13484 batches | lr 3.32 | ms/batch 115.30 | loss 4.83 | ppl 125.31\n",
"| epoch 9 | 4400/13484 batches | lr 3.32 | ms/batch 115.44 | loss 4.84 | ppl 126.39\n",
"| epoch 9 | 4600/13484 batches | lr 3.32 | ms/batch 115.29 | loss 4.88 | ppl 131.31\n",
"| epoch 9 | 4800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.86 | ppl 129.65\n",
"| epoch 9 | 5000/13484 batches | lr 3.32 | ms/batch 115.37 | loss 4.85 | ppl 128.14\n",
"| epoch 9 | 5200/13484 batches | lr 3.32 | ms/batch 115.37 | loss 4.85 | ppl 128.35\n",
"| epoch 9 | 5400/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.83 | ppl 124.69\n",
"| epoch 9 | 5600/13484 batches | lr 3.32 | ms/batch 115.47 | loss 4.85 | ppl 127.26\n",
"| epoch 9 | 5800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.85 | ppl 127.16\n",
"| epoch 9 | 6000/13484 batches | lr 3.32 | ms/batch 115.45 | loss 4.86 | ppl 128.86\n",
"| epoch 9 | 6200/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.88 | ppl 131.85\n",
"| epoch 9 | 6400/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.86 | ppl 129.32\n",
"| epoch 9 | 6600/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.89 | ppl 132.53\n",
"| epoch 9 | 6800/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.83 | ppl 125.33\n",
"| epoch 9 | 7000/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.88 | ppl 131.81\n",
"| epoch 9 | 7200/13484 batches | lr 3.32 | ms/batch 115.35 | loss 4.85 | ppl 127.28\n",
"| epoch 9 | 7400/13484 batches | lr 3.32 | ms/batch 115.42 | loss 4.85 | ppl 127.94\n",
"| epoch 9 | 7600/13484 batches | lr 3.32 | ms/batch 115.42 | loss 4.85 | ppl 127.51\n",
"| epoch 9 | 7800/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.85 | ppl 127.59\n",
"| epoch 9 | 8000/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.84 | ppl 126.24\n",
"| epoch 9 | 8200/13484 batches | lr 3.32 | ms/batch 115.46 | loss 4.82 | ppl 124.14\n",
"| epoch 9 | 8400/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.85 | ppl 127.32\n",
"| epoch 9 | 8600/13484 batches | lr 3.32 | ms/batch 115.37 | loss 4.86 | ppl 128.81\n",
"| epoch 9 | 8800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.83 | ppl 125.56\n",
"| epoch 9 | 9000/13484 batches | lr 3.32 | ms/batch 115.35 | loss 4.85 | ppl 128.24\n",
"| epoch 9 | 9200/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.87 | ppl 130.12\n",
"| epoch 9 | 9400/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.86 | ppl 129.31\n",
"| epoch 9 | 9600/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.84 | ppl 126.04\n",
"| epoch 9 | 9800/13484 batches | lr 3.32 | ms/batch 115.47 | loss 4.81 | ppl 122.88\n",
"| epoch 9 | 10000/13484 batches | lr 3.32 | ms/batch 115.43 | loss 4.84 | ppl 126.54\n",
"| epoch 9 | 10200/13484 batches | lr 3.32 | ms/batch 115.43 | loss 4.80 | ppl 121.48\n",
"| epoch 9 | 10400/13484 batches | lr 3.32 | ms/batch 115.35 | loss 4.78 | ppl 118.65\n",
"| epoch 9 | 10600/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.83 | ppl 124.64\n",
"| epoch 9 | 10800/13484 batches | lr 3.32 | ms/batch 115.47 | loss 4.85 | ppl 127.13\n",
"| epoch 9 | 11000/13484 batches | lr 3.32 | ms/batch 115.47 | loss 4.85 | ppl 127.77\n",
"| epoch 9 | 11200/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.84 | ppl 126.57\n",
"| epoch 9 | 11400/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.79 | ppl 120.87\n",
"| epoch 9 | 11600/13484 batches | lr 3.32 | ms/batch 115.33 | loss 4.83 | ppl 125.52\n",
"| epoch 9 | 11800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.80 | ppl 120.94\n",
"| epoch 9 | 12000/13484 batches | lr 3.32 | ms/batch 115.43 | loss 4.85 | ppl 127.12\n",
"| epoch 9 | 12200/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.78 | ppl 119.42\n",
"| epoch 9 | 12400/13484 batches | lr 3.32 | ms/batch 115.45 | loss 4.79 | ppl 120.40\n",
"| epoch 9 | 12600/13484 batches | lr 3.32 | ms/batch 115.44 | loss 4.82 | ppl 124.30\n",
"| epoch 9 | 12800/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.84 | ppl 126.27\n",
"| epoch 9 | 13000/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.83 | ppl 125.13\n",
"| epoch 9 | 13200/13484 batches | lr 3.32 | ms/batch 115.39 | loss 4.83 | ppl 125.83\n",
"| epoch 9 | 13400/13484 batches | lr 3.32 | ms/batch 115.42 | loss 4.85 | ppl 128.06\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 9 | time: 1628.05s | valid loss 5.02 | valid ppl 150.80\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 10 | 200/13484 batches | lr 3.15 | ms/batch 116.03 | loss 4.88 | ppl 131.35\n",
"| epoch 10 | 400/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.84 | ppl 126.70\n",
"| epoch 10 | 600/13484 batches | lr 3.15 | ms/batch 115.47 | loss 4.82 | ppl 124.18\n",
"| epoch 10 | 800/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.82 | ppl 123.47\n",
"| epoch 10 | 1000/13484 batches | lr 3.15 | ms/batch 115.31 | loss 4.82 | ppl 124.52\n",
"| epoch 10 | 1200/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.83 | ppl 124.69\n",
"| epoch 10 | 1400/13484 batches | lr 3.15 | ms/batch 115.49 | loss 4.81 | ppl 122.50\n",
"| epoch 10 | 1600/13484 batches | lr 3.15 | ms/batch 115.44 | loss 4.85 | ppl 127.35\n",
"| epoch 10 | 1800/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.83 | ppl 124.97\n",
"| epoch 10 | 2000/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.80 | ppl 121.45\n",
"| epoch 10 | 2200/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.80 | ppl 121.97\n",
"| epoch 10 | 2400/13484 batches | lr 3.15 | ms/batch 115.37 | loss 4.80 | ppl 122.05\n",
"| epoch 10 | 2600/13484 batches | lr 3.15 | ms/batch 115.46 | loss 4.79 | ppl 120.16\n",
"| epoch 10 | 2800/13484 batches | lr 3.15 | ms/batch 115.42 | loss 4.83 | ppl 125.44\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 10 | 3000/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.81 | ppl 122.12\n",
"| epoch 10 | 3200/13484 batches | lr 3.15 | ms/batch 115.44 | loss 4.79 | ppl 120.18\n",
"| epoch 10 | 3400/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.84 | ppl 127.05\n",
"| epoch 10 | 3600/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.82 | ppl 123.70\n",
"| epoch 10 | 3800/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.80 | ppl 121.74\n",
"| epoch 10 | 4000/13484 batches | lr 3.15 | ms/batch 115.49 | loss 4.76 | ppl 116.53\n",
"| epoch 10 | 4200/13484 batches | lr 3.15 | ms/batch 115.42 | loss 4.78 | ppl 119.64\n",
"| epoch 10 | 4400/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.80 | ppl 121.17\n",
"| epoch 10 | 4600/13484 batches | lr 3.15 | ms/batch 115.42 | loss 4.84 | ppl 126.61\n",
"| epoch 10 | 4800/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.83 | ppl 124.71\n",
"| epoch 10 | 5000/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.81 | ppl 122.89\n",
"| epoch 10 | 5200/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.81 | ppl 123.00\n",
"| epoch 10 | 5400/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.79 | ppl 120.50\n",
"| epoch 10 | 5600/13484 batches | lr 3.15 | ms/batch 115.38 | loss 4.80 | ppl 121.56\n",
"| epoch 10 | 5800/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.80 | ppl 121.20\n",
"| epoch 10 | 6000/13484 batches | lr 3.15 | ms/batch 115.38 | loss 4.82 | ppl 123.72\n",
"| epoch 10 | 6200/13484 batches | lr 3.15 | ms/batch 115.35 | loss 4.85 | ppl 127.61\n",
"| epoch 10 | 6400/13484 batches | lr 3.15 | ms/batch 115.32 | loss 4.82 | ppl 124.04\n",
"| epoch 10 | 6600/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.85 | ppl 127.34\n",
"| epoch 10 | 6800/13484 batches | lr 3.15 | ms/batch 115.38 | loss 4.80 | ppl 121.21\n",
"| epoch 10 | 7000/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.84 | ppl 126.43\n",
"| epoch 10 | 7200/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.81 | ppl 122.41\n",
"| epoch 10 | 7400/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.81 | ppl 122.46\n",
"| epoch 10 | 7600/13484 batches | lr 3.15 | ms/batch 115.37 | loss 4.80 | ppl 122.05\n",
"| epoch 10 | 7800/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.80 | ppl 121.58\n",
"| epoch 10 | 8000/13484 batches | lr 3.15 | ms/batch 115.32 | loss 4.79 | ppl 120.04\n",
"| epoch 10 | 8200/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.78 | ppl 118.97\n",
"| epoch 10 | 8400/13484 batches | lr 3.15 | ms/batch 115.44 | loss 4.80 | ppl 121.55\n",
"| epoch 10 | 8600/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.82 | ppl 123.48\n",
"| epoch 10 | 8800/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.79 | ppl 119.98\n",
"| epoch 10 | 9000/13484 batches | lr 3.15 | ms/batch 115.37 | loss 4.80 | ppl 121.60\n",
"| epoch 10 | 9200/13484 batches | lr 3.15 | ms/batch 115.50 | loss 4.82 | ppl 124.26\n",
"| epoch 10 | 9400/13484 batches | lr 3.15 | ms/batch 115.44 | loss 4.83 | ppl 124.68\n",
"| epoch 10 | 9600/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.79 | ppl 120.72\n",
"| epoch 10 | 9800/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.77 | ppl 118.27\n",
"| epoch 10 | 10000/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.80 | ppl 121.83\n",
"| epoch 10 | 10200/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.76 | ppl 116.50\n",
"| epoch 10 | 10400/13484 batches | lr 3.15 | ms/batch 115.34 | loss 4.74 | ppl 114.00\n",
"| epoch 10 | 10600/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.79 | ppl 119.78\n",
"| epoch 10 | 10800/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.80 | ppl 121.63\n",
"| epoch 10 | 11000/13484 batches | lr 3.15 | ms/batch 115.32 | loss 4.80 | ppl 121.41\n",
"| epoch 10 | 11200/13484 batches | lr 3.15 | ms/batch 115.52 | loss 4.80 | ppl 121.11\n",
"| epoch 10 | 11400/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.75 | ppl 115.26\n",
"| epoch 10 | 11600/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.79 | ppl 120.63\n",
"| epoch 10 | 11800/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.75 | ppl 115.77\n",
"| epoch 10 | 12000/13484 batches | lr 3.15 | ms/batch 115.48 | loss 4.80 | ppl 121.70\n",
"| epoch 10 | 12200/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.74 | ppl 114.59\n",
"| epoch 10 | 12400/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.76 | ppl 116.79\n",
"| epoch 10 | 12600/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.77 | ppl 118.25\n",
"| epoch 10 | 12800/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.79 | ppl 120.47\n",
"| epoch 10 | 13000/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.78 | ppl 119.53\n",
"| epoch 10 | 13200/13484 batches | lr 3.15 | ms/batch 115.42 | loss 4.79 | ppl 120.28\n",
"| epoch 10 | 13400/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.81 | ppl 122.17\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 10 | time: 1628.47s | valid loss 4.98 | valid ppl 145.33\n",
"-----------------------------------------------------------------------------------------\n"
]
}
],
"source": [
"# Train for `epochs` epochs, keeping a deep copy of the model with the\n",
"# lowest validation loss seen so far (checkpoint-on-best selection).\n",
"best_val_loss = float('inf')\n",
"epochs = 10\n",
"best_model = None\n",
"\n",
"for epoch in range(1, epochs + 1):\n",
"    epoch_start_time = time.time()\n",
"    train(model)\n",
"    # validation loss drives both model selection and the reported perplexity\n",
"    val_loss = evaluate(model, val_data)\n",
"    val_ppl = math.exp(val_loss)\n",
"    elapsed = time.time() - epoch_start_time\n",
"    print('-' * 89)\n",
"    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '\n",
"          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')\n",
"    print('-' * 89)\n",
"\n",
"    # snapshot the best model; deepcopy so further training epochs\n",
"    # cannot overwrite the saved weights\n",
"    if val_loss < best_val_loss:\n",
"        best_val_loss = val_loss\n",
"        best_model = copy.deepcopy(model)\n",
"\n",
"    # advance the learning-rate schedule once per epoch\n",
"    scheduler.step()"
]
},
{
"cell_type": "markdown",
"id": "f0d32419",
"metadata": {},
"source": [
"### print info about best model after training"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "12fdd0aa",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=========================================================================================\n",
"| End of training | test loss 4.98 | test ppl 144.89\n",
"=========================================================================================\n"
]
}
],
"source": [
"# Report final held-out performance of the best checkpoint (loss and perplexity)\n",
"test_loss = evaluate(best_model, test_data)\n",
"test_ppl = math.exp(test_loss)\n",
"print('=' * 89)\n",
"print(f'| End of training | test loss {test_loss:5.2f} | '\n",
"      f'test ppl {test_ppl:8.2f}')\n",
"print('=' * 89)"
]
},
{
"cell_type": "markdown",
"id": "528c9f10",
"metadata": {},
"source": [
"### save trained model to file"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "848af399",
"metadata": {},
"outputs": [],
"source": [
"# persist the best-performing model so later sessions can skip training\n",
"torch.save(best_model, \"pubmed-sentencecomplete.pt\")"
]
},
{
"cell_type": "markdown",
"id": "09df56cf",
"metadata": {},
"source": [
"## Now we can try to predict based on the trained model"
]
},
{
"cell_type": "markdown",
"id": "fe250072",
"metadata": {},
"source": [
"### define a tokenizing iterator over the prediction batch"
]
},
{
"cell_type": "code",
"execution_count": 159,
"id": "afe585d6",
"metadata": {},
"outputs": [],
"source": [
"def predict_abstract_iter(batch):\n",
"    \"\"\"Yield the tokenized form of each text in `batch`.\n",
"\n",
"    Uses the module-level `tokenizer`; yields one token list per input string.\n",
"    \"\"\"\n",
"    for text in batch:  # renamed: original loop variable shadowed the parameter\n",
"        yield tokenizer(text)"
]
},
{
"cell_type": "markdown",
"id": "b043de0a",
"metadata": {},
"source": [
"### load data into tensor for model to process"
]
},
{
"cell_type": "code",
"execution_count": 154,
"id": "8bfaa8bd",
"metadata": {},
"outputs": [],
"source": [
"def toDataTensor(batch):\n",
"    \"\"\"Convert each text in `batch` into a tensor of vocabulary indices.\"\"\"\n",
"    tensors = []\n",
"    for tokens in predict_abstract_iter(batch):\n",
"        tensors.append(torch.tensor(vocab.lookup_indices(tokens)))\n",
"    return tensors"
]
},
{
"cell_type": "markdown",
"id": "a800ffea",
"metadata": {},
"source": [
"### determine the compute device (redundant if it was already set above)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6e2c35ba",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'torch' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m device \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 2\u001b[0m device\n",
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
]
}
],
"source": [
"# pick the GPU when available, otherwise fall back to the CPU\n",
"use_cuda = torch.cuda.is_available()\n",
"device = torch.device('cuda' if use_cuda else 'cpu')\n",
"device"
]
},
{
"cell_type": "markdown",
"id": "bef90722",
"metadata": {},
"source": [
"### optionally load model from file if it was trained already"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "223eed8a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TransformerModel(\n",
" (pos_encoder): PositionalEncoding(\n",
" (dropout): Dropout(p=0.2, inplace=False)\n",
" )\n",
" (transformer_encoder): TransformerEncoder(\n",
" (layers): ModuleList(\n",
" (0): TransformerEncoderLayer(\n",
" (self_attn): MultiheadAttention(\n",
" (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)\n",
" )\n",
" (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
" (dropout): Dropout(p=0.2, inplace=False)\n",
" (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
" (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
" (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
" (dropout1): Dropout(p=0.2, inplace=False)\n",
" (dropout2): Dropout(p=0.2, inplace=False)\n",
" )\n",
" (1): TransformerEncoderLayer(\n",
" (self_attn): MultiheadAttention(\n",
" (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)\n",
" )\n",
" (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
" (dropout): Dropout(p=0.2, inplace=False)\n",
" (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
" (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
" (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
" (dropout1): Dropout(p=0.2, inplace=False)\n",
" (dropout2): Dropout(p=0.2, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (encoder): Embedding(163987, 200)\n",
" (decoder): Linear(in_features=200, out_features=163987, bias=True)\n",
")"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# map_location lets a checkpoint saved on GPU load on a CPU-only machine;\n",
"# eval() disables dropout for inference\n",
"best_model = torch.load(\"pubmed-sentencecomplete.pt\", map_location=device)\n",
"best_model.eval()"
]
},
{
"cell_type": "markdown",
"id": "dd71bdfc",
"metadata": {},
"source": [
"### define predict function"
]
},
{
"cell_type": "code",
"execution_count": 160,
"id": "64223e87",
"metadata": {},
"outputs": [],
"source": [
"def predict(input_line, mask, n_predictions=3):\n",
"    \"\"\"Return the `n_predictions` most likely next tokens for `input_line`.\n",
"\n",
"    Runs `best_model` without gradient tracking and looks up the k-th best\n",
"    index for the final position, for k = 1 .. n_predictions.\n",
"    \"\"\"\n",
"    predictions = []\n",
"    with torch.no_grad():\n",
"        output = best_model(input_line.to(device), mask)\n",
"        for rank in range(1, n_predictions + 1):\n",
"            token_index = output.topk(rank)[1].view(-1)[-1].item()\n",
"            predictions.append(vocab.lookup_token(token_index))\n",
"    return predictions"
]
},
{
"cell_type": "markdown",
"id": "a9b7311b",
"metadata": {},
"source": [
"### define the input batch of seed text"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "913628b4",
"metadata": {},
"outputs": [],
"source": [
"# seed text(s) the model will be asked to continue\n",
"sample_batch = [\n",
"    \"There is\"\n",
"]\n",
"input_batch = sample_batch"
]
},
{
"cell_type": "markdown",
"id": "45930a71",
"metadata": {},
"source": [
"### define initial source mask for model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c4bba2a1",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'generate_square_subsequent_mask' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m bptt \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 2\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(bptt)\u001b[38;5;241m.\u001b[39mto(device)\n",
"\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined"
]
}
],
"source": [
"# bptt sets the initial mask size; predict_loop grows it by one per iteration\n",
"bptt = 2\n",
"src_mask = generate_square_subsequent_mask(bptt).to(device)"
]
},
{
"cell_type": "markdown",
"id": "5b33b9f3",
"metadata": {},
"source": [
"### Run interactive prediction: display candidate tokens and choose a continuation"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b2895698",
"metadata": {},
"outputs": [],
"source": [
"def predict_loop(num_of_pred):\n",
"    \"\"\"Interactively extend each seed text one token at a time.\n",
"\n",
"    Shows the top `num_of_pred` candidate tokens, reads the user's choice\n",
"    (a 1-based index) and appends the chosen token; entering anything\n",
"    containing 'e' stops the session.\n",
"    \"\"\"\n",
"    iteration = 0\n",
"    is_terminated = False\n",
"    # copy the seed list so repeated runs do not mutate the global in place\n",
"    input_batch = list(sample_batch)\n",
"    while not is_terminated:\n",
"        # the mask grows with the text: one extra position per accepted token\n",
"        mask_size = bptt + iteration\n",
"        src_mask = generate_square_subsequent_mask(mask_size).to(device)\n",
"        data = toDataTensor(input_batch)\n",
"\n",
"        for i, d in enumerate(data):\n",
"            predictions = predict(d, src_mask, num_of_pred)\n",
"\n",
"            print(\"\\n Possible continuations:\")\n",
"            for j in range(len(predictions)):\n",
"                print(j + 1, \": \", predictions[j])\n",
"            s_index = input(input_batch[i])\n",
"            if \"e\" in s_index:\n",
"                is_terminated = True\n",
"                print(\"prediction stopped.\")\n",
"                break\n",
"\n",
"            # reject anything that is not a valid 1-based candidate index\n",
"            # (original code raised ValueError/IndexError on bad input)\n",
"            if not s_index.isdigit() or not 1 <= int(s_index) <= len(predictions):\n",
"                print(\"invalid selection, please choose 1 -\", len(predictions))\n",
"                continue\n",
"\n",
"            print(\"Text is now:\")\n",
"            input_batch[i] += \" \" + predictions[int(s_index) - 1]\n",
"            print(input_batch[i])\n",
"\n",
"        iteration = iteration + 1"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "13ed9298",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'generate_square_subsequent_mask' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mpredict_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn [4], line 7\u001b[0m, in \u001b[0;36mpredict_loop\u001b[0;34m(num_of_pred)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m(\u001b[38;5;129;01mnot\u001b[39;00m is_terminated):\n\u001b[1;32m 6\u001b[0m mask_size \u001b[38;5;241m=\u001b[39m bptt\u001b[38;5;241m+\u001b[39m(iteration) \n\u001b[0;32m----> 7\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(mask_size)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m toDataTensor(input_batch)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, d \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(data):\n",
"\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined"
]
}
],
"source": [
"# start the interactive completion session with 3 candidates per step\n",
"predict_loop(3)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}