{ "cells": [ { "cell_type": "markdown", "id": "622dfcd6", "metadata": {}, "source": [ "# Group 09 - Automatic Sentence Completion for PubMed" ] }, { "cell_type": "markdown", "id": "5e4cec3c", "metadata": {}, "source": [ "### Authors:\n", "- Constantin Fuerst\t\n", "- Leonard Starke" ] }, { "cell_type": "markdown", "id": "806cfb27", "metadata": {}, "source": [ "### link to \"Attention is All You Need\" paper describing transformer models" ] }, { "cell_type": "code", "execution_count": null, "id": "fe862072", "metadata": {}, "outputs": [], "source": [ "https://arxiv.org/pdf/1706.03762.pdf" ] }, { "cell_type": "markdown", "id": "fa161b1b", "metadata": {}, "source": [ "### load query data from text file " ] }, { "cell_type": "code", "execution_count": 2, "id": "e1912a79", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2023-01-18 15:48:24-- https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download\n", "Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'\n", "Resolving cloud.constantin-fuerst.com (cloud.constantin-fuerst.com)... 95.91.21.14\n", "Connecting to cloud.constantin-fuerst.com (cloud.constantin-fuerst.com)|95.91.21.14|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 1100551 (1.0M) [text/plain]\n", "Saving to: ‘pubmed-query.txt’\n", "\n", "pubmed-query.txt 100%[===================>] 1.05M 1.61MB/s in 0.7s \n", "\n", "2023-01-18 15:48:25 (1.61 MB/s) - ‘pubmed-query.txt’ saved [1100551/1100551]\n", "\n" ] } ], "source": [ "!wget https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download -O pubmed-query.txt" ] }, { "cell_type": "markdown", "id": "da068411", "metadata": {}, "source": [ "### import modules used for parsing query data" ] }, { "cell_type": "code", "execution_count": null, "id": "c10bc5a8", "metadata": {}, "outputs": [], "source": [ "try:\n", " from Bio import Medline\n", "except: \n", " !pip install Bio\n", " from Bio import Medline" ] }, { "cell_type": "markdown", "id": "7bf15c30", "metadata": {}, "source": [ "### define function for loading the papers from PubMed database" ] }, { "cell_type": "code", "execution_count": 1, "id": "adfb256a", "metadata": {}, "outputs": [], "source": [ "def getPapers(filename):\n", " pubmed_query = open(filename, encoding='utf-8')\n", " records = Medline.parse(pubmed_query)\n", " return list(records)" ] }, { "cell_type": "markdown", "id": "46bc6298", "metadata": {}, "source": [ "### Verify that its working" ] }, { "cell_type": "code", "execution_count": 4, "id": "00481ec9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Got 150000 records from the query text file\n" ] } ], "source": [ "max_records = 150000\n", "records = getPapers(\"pubmed-query.txt\")\n", "records = records[:min(max_records, len(records))]\n", "print(f\"Got {len(records)} records from the query text file\")" ] }, { "cell_type": "markdown", "id": "b67747c6", "metadata": {}, "source": [ "### Now extract abstracts from records" ] }, { "cell_type": "code", "execution_count": 5, "id": "dcf5c217", "metadata": {}, "outputs": [], "source": [ "r_abstracts = []\n", "for r in records:\n", " if not (r.get('AB') is None):\n", " r_abstracts.append(r['AB'])" ] }, { "cell_type": "markdown", "id": "e309f6fe", "metadata": {}, "source": [ "### Now import torch modules needed to load the data" ] }, { "cell_type": "code", "execution_count": 6, "id": "c3199444", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"/home/hein/.local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "try:\n", " import torch\n", " from torch.utils.data import Dataset \n", " from torchtext.data import get_tokenizer\n", "except:\n", " !pip --default-timeout=1000 install torch\n", " !pip --default-timeout=1000 install torchtext\n", " import torch\n", " from torch.utils.data import Dataset \n", " from torchtext.data import get_tokenizer" ] }, { "cell_type": "markdown", "id": "5b4007e8", "metadata": {}, "source": [ "### Import numpy" ] }, { "cell_type": "code", "execution_count": 7, "id": "daca9db6", "metadata": {}, "outputs": [], "source": [ "try:\n", " import numpy as np\n", "except:\n", " !pip install numpy\n", " import numpy as np\n" ] }, { "cell_type": "markdown", "id": "683ed2fc", "metadata": {}, "source": [ "### import math module" ] }, { "cell_type": "code", "execution_count": 8, "id": "8d2312db", "metadata": {}, "outputs": [], "source": [ "import math" ] }, { "cell_type": "markdown", "id": "4df1e449", "metadata": {}, "source": [ "### define token iterators" ] }, { "cell_type": "code", "execution_count": 9, "id": "3f23404d", "metadata": {}, "outputs": [], "source": [ "train_size = math.floor(len(r_abstracts) * 0.75)\n", "val_size = math.floor(len(r_abstracts) * 0.125)\n", "test_size = math.floor(len(r_abstracts) * 0.125)" ] }, { "cell_type": "code", "execution_count": 10, "id": "8a128d3c", "metadata": {}, "outputs": [], "source": [ "def train_abstract_iter():\n", " for abstract in r_abstracts[:train_size]:\n", " yield abstract" ] }, { "cell_type": "code", "execution_count": 11, "id": "97e89986", "metadata": {}, "outputs": [], "source": [ "def val_abstract_iter():\n", " for abstract in r_abstracts[(train_size + 1):(train_size + val_size)]:\n", " yield abstract" ] }, { "cell_type": "code", "execution_count": 12, "id": "0d6e89c4", "metadata": {}, "outputs": [], "source": [ "def test_abstract_iter():\n", " for abstract in r_abstracts[(train_size + val_size + 1): (train_size + val_size + test_size)]:\n", " yield abstract" ] }, { "cell_type": "markdown", "id": "e5e9c5a2", "metadata": {}, "source": [ "### define Tokenize function" ] }, { "cell_type": "code", "execution_count": 13, "id": "0bdbc40a", "metadata": {}, "outputs": [], "source": [ "tokenizer = get_tokenizer(\"basic_english\")\n", "def tokenize_abstract_iter():\n", " for abstract in r_abstracts:\n", " yield tokenizer(abstract)" ] }, { "cell_type": "markdown", "id": "37da40bb", "metadata": {}, "source": [ "### Map every word to an id to store inside torch tensor" ] }, { "cell_type": "code", "execution_count": 14, "id": "a438ab1f", "metadata": {}, "outputs": [], "source": [ "from torchtext.vocab import build_vocab_from_iterator\n", "token_generator = tokenize_abstract_iter()\n", "vocab = build_vocab_from_iterator(token_generator, specials=[''])\n", "vocab.set_default_index(vocab[''])\n" ] }, { "cell_type": "markdown", "id": "221bdc48", "metadata": {}, "source": [ "### now convert to tensor\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "0e5bc361", "metadata": {}, "outputs": [], "source": [ "def data_process(tokens_iter):\n", " \"\"\"Converts raw text into a flat Tensor.\"\"\"\n", " data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in tokens_iter]\n", " return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))" ] }, { 
"cell_type": "code", "execution_count": 16, "id": "dfd7400d", "metadata": {}, "outputs": [], "source": [ "train_generator = train_abstract_iter()\n", "val_generator = val_abstract_iter()\n", "test_generator = test_abstract_iter()\n", "train_data = data_process(train_generator)\n", "val_data = data_process(val_generator)\n", "test_data = data_process(test_generator)" ] }, { "cell_type": "markdown", "id": "c49a2734", "metadata": {}, "source": [ "### check gpu" ] }, { "cell_type": "code", "execution_count": 20, "id": "c155ee31", "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": 21, "id": "79b2d248", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device" ] }, { "cell_type": "markdown", "id": "2150ba71", "metadata": {}, "source": [ "### define model" ] }, { "cell_type": "code", "execution_count": 22, "id": "a33d722f", "metadata": {}, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "from torch import nn, Tensor\n", "import torch.nn.functional as F\n", "from torch.nn import TransformerEncoder, TransformerEncoderLayer\n", "from torch.utils.data import dataset\n", "\n", "class TransformerModel(nn.Module):\n", "\n", " def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,\n", " nlayers: int, dropout: float = 0.5):\n", " super().__init__()\n", " self.model_type = 'Transformer'\n", " self.pos_encoder = PositionalEncoding(d_model, dropout)\n", " encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)\n", " self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n", " self.encoder = nn.Embedding(ntoken, d_model)\n", " self.d_model = d_model\n", " self.decoder = nn.Linear(d_model, ntoken)\n", "\n", " self.init_weights()\n", "\n", " def init_weights(self) -> None:\n", " initrange = 0.1\n", " self.encoder.weight.data.uniform_(-initrange, initrange)\n", " self.decoder.bias.data.zero_()\n", " self.decoder.weight.data.uniform_(-initrange, initrange)\n", "\n", " def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " src: Tensor, shape [seq_len, batch_size]\n", " src_mask: Tensor, shape [seq_len, seq_len]\n", "\n", " Returns:\n", " output Tensor of shape [seq_len, batch_size, ntoken]\n", " \"\"\"\n", " src = self.encoder(src) * math.sqrt(self.d_model)\n", " src = self.pos_encoder(src)\n", " output = self.transformer_encoder(src, src_mask)\n", " output = self.decoder(output)\n", " return output\n", "\n", "\n", "def generate_square_subsequent_mask(sz: int) -> Tensor:\n", " \"\"\"Generates an upper-triangular matrix of -inf, with zeros on diag.\"\"\"\n", " return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)" ] }, { "cell_type": "markdown", "id": "23268efe", "metadata": {}, "source": [ "### define pos encoder" ] }, { "cell_type": "code", "execution_count": 23, "id": "c2f6d33b", "metadata": {}, "outputs": [], "source": [ "class PositionalEncoding(nn.Module):\n", "\n", " def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):\n", " super().__init__()\n", " self.dropout = nn.Dropout(p=dropout)\n", "\n", " position = torch.arange(max_len).unsqueeze(1)\n", " div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n", " pe = torch.zeros(max_len, 1, d_model)\n", " pe[:, 0, 0::2] = torch.sin(position * div_term)\n", " pe[:, 0, 1::2] 
= torch.cos(position * div_term)\n", " self.register_buffer('pe', pe)\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " x: Tensor, shape [seq_len, batch_size, embedding_dim]\n", " \"\"\"\n", " x = x + self.pe[:x.size(0)]\n", " return self.dropout(x)\n" ] }, { "cell_type": "markdown", "id": "306352f5", "metadata": {}, "source": [ "### define function to create batches of data and create batches" ] }, { "cell_type": "code", "execution_count": 24, "id": "9e184841", "metadata": {}, "outputs": [], "source": [ "def batchify(data: Tensor, bsz: int) -> Tensor:\n", " \"\"\"Divides the data into bsz separate sequences, removing extra elements\n", " that wouldn't cleanly fit.\n", "\n", " Args:\n", " data: Tensor, shape [N]\n", " bsz: int, batch size\n", "\n", " Returns:\n", " Tensor of shape [N // bsz, bsz]\n", " \"\"\"\n", " seq_len = data.size(0) // bsz\n", " data = data[:seq_len * bsz]\n", " data = data.view(bsz, seq_len).t().contiguous()\n", " return data.to(device)" ] }, { "cell_type": "code", "execution_count": 25, "id": "a4def1ac", "metadata": {}, "outputs": [], "source": [ "batch_size = 20\n", "eval_batch_size = 10\n", "train_data = batchify(train_data, batch_size) # shape [seq_len, batch_size]\n", "val_data = batchify(val_data, eval_batch_size)\n", "test_data = batchify(test_data, eval_batch_size)" ] }, { "cell_type": "markdown", "id": "c658cb42", "metadata": {}, "source": [ "### define function to get batch" ] }, { "cell_type": "code", "execution_count": 26, "id": "4ab5b8fd", "metadata": {}, "outputs": [], "source": [ "bptt = 35\n", "def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:\n", " \"\"\"\n", " Args:\n", " source: Tensor, shape [full_seq_len, batch_size]\n", " i: int\n", "\n", " Returns:\n", " tuple (data, target), where data has shape [seq_len, batch_size] and\n", " target has shape [seq_len * batch_size]\n", " \"\"\"\n", " seq_len = min(bptt, len(source) - 1 - i)\n", " data = source[i:i+seq_len]\n", " target = source[i+1:i+1+seq_len].reshape(-1)\n", " return data, target" ] }, { "cell_type": "markdown", "id": "d6392484", "metadata": {}, "source": [ "### define parameters and init model" ] }, { "cell_type": "code", "execution_count": 27, "id": "c53764da", "metadata": {}, "outputs": [], "source": [ "ntokens = len(vocab) # size of vocabulary\n", "emsize = 200 # embedding dimension\n", "d_hid = 200 # dimension of the feedforward network model in nn.TransformerEncoder\n", "nlayers = 2 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder\n", "nhead = 2 # number of heads in nn.MultiheadAttention\n", "dropout = 0.2 # dropout probability\n", "model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)" ] }, { "cell_type": "markdown", "id": "7fb67d72", "metadata": {}, "source": [ "### init optimizer, loss, scheduler etc." 
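, "\n", "\n", "But first, two quick sanity checks. Both are illustrative sketches added here, assuming the cells above have run; they use only standard PyTorch calls and introduce new throwaway variable names.\n", "\n", "The batching scheme turns the flat token stream into chunks of `bptt` tokens whose targets are the same chunk shifted one token ahead:" ] }, { "cell_type": "code", "execution_count": null, "id": "0ff1ce01", "metadata": {}, "outputs": [], "source": [ "# illustrative sketch: inspect the shapes produced by get_batch\n", "demo_data, demo_targets = get_batch(train_data, 0)\n", "print(demo_data.shape)     # expected [bptt, batch_size], i.e. [35, 20]\n", "print(demo_targets.shape)  # expected [bptt * batch_size], i.e. [700]" ] }, { "cell_type": "markdown", "id": "0ff1ce02", "metadata": {}, "source": [ "And a look at the size of the model configured above:" ] }, { "cell_type": "code", "execution_count": null, "id": "0ff1ce03", "metadata": {}, "outputs": [], "source": [ "# illustrative sketch: count the trainable parameters of the model\n", "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "print(f'the model has {n_params:,} trainable parameters')" ] }, { "cell_type": "markdown", "id": "0ff1ce04", "metadata": {}, "source": [ "Now initialize the loss, the optimizer and the learning-rate scheduler."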
] }, { "cell_type": "code", "execution_count": 28, "id": "ddaa1d64", "metadata": {}, "outputs": [], "source": [ "import copy\n", "import time\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "lr = 5.0 # learning rate\n", "optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)" ] }, { "cell_type": "markdown", "id": "dda19446", "metadata": {}, "source": [ "### define train function" ] }, { "cell_type": "code", "execution_count": 29, "id": "50ab3fb6", "metadata": {}, "outputs": [], "source": [ "def train(model: nn.Module) -> None:\n", " model.train() # turn on train mode\n", " total_loss = 0.\n", " log_interval = 200\n", " start_time = time.time()\n", " src_mask = generate_square_subsequent_mask(bptt).to(device)\n", "\n", " num_batches = len(train_data) // bptt\n", " for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n", " data, targets = get_batch(train_data, i)\n", " seq_len = data.size(0)\n", " if seq_len != bptt: # only on last batch\n", " src_mask = src_mask[:seq_len, :seq_len]\n", " output = model(data, src_mask)\n", " loss = criterion(output.view(-1, ntokens), targets)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n", " optimizer.step()\n", "\n", " total_loss += loss.item()\n", " if batch % log_interval == 0 and batch > 0:\n", " lr = scheduler.get_last_lr()[0]\n", " ms_per_batch = (time.time() - start_time) * 1000 / log_interval\n", " cur_loss = total_loss / log_interval\n", " ppl = math.exp(cur_loss)\n", " print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '\n", " f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n", " f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n", " total_loss = 0\n", " start_time = time.time()" ] }, { "cell_type": "markdown", "id": "9756c092", "metadata": {}, "source": [ "### define evaluate function" ] }, { "cell_type": "code", "execution_count": 30, "id": "3d179bb0", "metadata": {}, "outputs": [], "source": [ "def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n", " model.eval() # turn on evaluation mode\n", " total_loss = 0.\n", " src_mask = generate_square_subsequent_mask(bptt).to(device)\n", " with torch.no_grad():\n", " for i in range(0, eval_data.size(0) - 1, bptt):\n", " data, targets = get_batch(eval_data, i)\n", " seq_len = data.size(0)\n", " if seq_len != bptt:\n", " src_mask = src_mask[:seq_len, :seq_len]\n", " output = model(data, src_mask)\n", " output_flat = output.view(-1, ntokens)\n", " total_loss += seq_len * criterion(output_flat, targets).item()\n", " return total_loss / (len(eval_data) - 1)" ] }, { "cell_type": "markdown", "id": "5a959f09", "metadata": {}, "source": [ "### now we can start training the model while saving best one" ] }, { "cell_type": "code", "execution_count": 31, "id": "09c4d4ce", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "| epoch 1 | 200/13484 batches | lr 5.00 | ms/batch 116.24 | loss 9.27 | ppl 10651.55\n", "| epoch 1 | 400/13484 batches | lr 5.00 | ms/batch 114.02 | loss 7.49 | ppl 1787.62\n", "| epoch 1 | 600/13484 batches | lr 5.00 | ms/batch 114.33 | loss 6.83 | ppl 923.44\n", "| epoch 1 | 800/13484 batches | lr 5.00 | ms/batch 114.54 | loss 6.54 | ppl 693.98\n", "| epoch 1 | 1000/13484 batches | lr 5.00 | ms/batch 114.73 | loss 6.33 | ppl 563.29\n", "| epoch 1 | 1200/13484 batches | lr 5.00 | ms/batch 114.85 | loss 6.18 | ppl 485.05\n", "| epoch 1 | 1400/13484 batches | lr 
5.00 | ms/batch 114.91 | loss 6.09 | ppl 440.69\n", "| epoch 1 | 1600/13484 batches | lr 5.00 | ms/batch 115.00 | loss 6.06 | ppl 428.38\n", "| epoch 1 | 1800/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.98 | ppl 397.07\n", "| epoch 1 | 2000/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.91 | ppl 369.13\n", "| epoch 1 | 2200/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.89 | ppl 360.14\n", "| epoch 1 | 2400/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.83 | ppl 341.10\n", "| epoch 1 | 2600/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.78 | ppl 322.33\n", "| epoch 1 | 2800/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.80 | ppl 329.27\n", "| epoch 1 | 3000/13484 batches | lr 5.00 | ms/batch 115.12 | loss 5.77 | ppl 321.64\n", "| epoch 1 | 3200/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.71 | ppl 303.37\n", "| epoch 1 | 3400/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.74 | ppl 311.04\n", "| epoch 1 | 3600/13484 batches | lr 5.00 | ms/batch 115.15 | loss 5.70 | ppl 299.44\n", "| epoch 1 | 3800/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.68 | ppl 292.67\n", "| epoch 1 | 4000/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.59 | ppl 268.70\n", "| epoch 1 | 4200/13484 batches | lr 5.00 | ms/batch 115.19 | loss 5.62 | ppl 275.23\n", "| epoch 1 | 4400/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.63 | ppl 277.51\n", "| epoch 1 | 4600/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.66 | ppl 286.99\n", "| epoch 1 | 4800/13484 batches | lr 5.00 | ms/batch 115.30 | loss 5.62 | ppl 276.08\n", "| epoch 1 | 5000/13484 batches | lr 5.00 | ms/batch 115.15 | loss 5.61 | ppl 272.68\n", "| epoch 1 | 5200/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.59 | ppl 268.83\n", "| epoch 1 | 5400/13484 batches | lr 5.00 | ms/batch 115.29 | loss 5.55 | ppl 257.80\n", "| epoch 1 | 5600/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.57 | ppl 261.32\n", "| epoch 1 | 5800/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.55 | ppl 257.06\n", "| epoch 1 | 6000/13484 batches | lr 5.00 | ms/batch 115.26 | loss 5.56 | ppl 259.08\n", "| epoch 1 | 6200/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.57 | ppl 262.89\n", "| epoch 1 | 6400/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.54 | ppl 254.66\n", "| epoch 1 | 6600/13484 batches | lr 5.00 | ms/batch 115.27 | loss 5.57 | ppl 263.01\n", "| epoch 1 | 6800/13484 batches | lr 5.00 | ms/batch 115.21 | loss 5.51 | ppl 246.13\n", "| epoch 1 | 7000/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.57 | ppl 261.50\n", "| epoch 1 | 7200/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.51 | ppl 247.48\n", "| epoch 1 | 7400/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.50 | ppl 245.45\n", "| epoch 1 | 7600/13484 batches | lr 5.00 | ms/batch 115.26 | loss 5.51 | ppl 247.79\n", "| epoch 1 | 7800/13484 batches | lr 5.00 | ms/batch 115.27 | loss 5.50 | ppl 245.74\n", "| epoch 1 | 8000/13484 batches | lr 5.00 | ms/batch 115.33 | loss 5.48 | ppl 240.49\n", "| epoch 1 | 8200/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.48 | ppl 238.87\n", "| epoch 1 | 8400/13484 batches | lr 5.00 | ms/batch 115.27 | loss 5.49 | ppl 241.45\n", "| epoch 1 | 8600/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.47 | ppl 236.88\n", "| epoch 1 | 8800/13484 batches | lr 5.00 | ms/batch 115.28 | loss 5.47 | ppl 236.31\n", "| epoch 1 | 9000/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.48 | ppl 240.63\n", "| epoch 1 | 9200/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.48 | ppl 239.53\n", "| epoch 1 | 9400/13484 batches 
| lr 5.00 | ms/batch 115.29 | loss 5.48 | ppl 238.75\n", "| epoch 1 | 9600/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.43 | ppl 229.14\n", "| epoch 1 | 9800/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.42 | ppl 226.49\n", "| epoch 1 | 10000/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.47 | ppl 236.79\n", "| epoch 1 | 10200/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.41 | ppl 223.98\n", "| epoch 1 | 10400/13484 batches | lr 5.00 | ms/batch 115.16 | loss 5.39 | ppl 219.63\n", "| epoch 1 | 10600/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.42 | ppl 225.37\n", "| epoch 1 | 10800/13484 batches | lr 5.00 | ms/batch 115.30 | loss 5.45 | ppl 232.44\n", "| epoch 1 | 11000/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.45 | ppl 232.12\n", "| epoch 1 | 11200/13484 batches | lr 5.00 | ms/batch 115.21 | loss 5.43 | ppl 228.71\n", "| epoch 1 | 11400/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.38 | ppl 216.73\n", "| epoch 1 | 11600/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.41 | ppl 222.68\n", "| epoch 1 | 11800/13484 batches | lr 5.00 | ms/batch 115.28 | loss 5.39 | ppl 218.39\n", "| epoch 1 | 12000/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.44 | ppl 229.94\n", "| epoch 1 | 12200/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.36 | ppl 213.26\n", "| epoch 1 | 12400/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.38 | ppl 217.41\n", "| epoch 1 | 12600/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.40 | ppl 222.35\n", "| epoch 1 | 12800/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.41 | ppl 224.63\n", "| epoch 1 | 13000/13484 batches | lr 5.00 | ms/batch 115.29 | loss 5.40 | ppl 220.79\n", "| epoch 1 | 13200/13484 batches | lr 5.00 | ms/batch 115.16 | loss 5.41 | ppl 223.58\n", "| epoch 1 | 13400/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.42 | ppl 225.49\n", "-----------------------------------------------------------------------------------------\n", "| end of epoch 1 | time: 1625.43s | valid loss 5.35 | valid ppl 210.93\n", "-----------------------------------------------------------------------------------------\n", "| epoch 2 | 200/13484 batches | lr 4.75 | ms/batch 115.84 | loss 5.44 | ppl 229.80\n", "| epoch 2 | 400/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.38 | ppl 216.74\n", "| epoch 2 | 600/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.35 | ppl 211.15\n", "| epoch 2 | 800/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.37 | ppl 215.74\n", "| epoch 2 | 1000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.35 | ppl 210.96\n", "| epoch 2 | 1200/13484 batches | lr 4.75 | ms/batch 115.17 | loss 5.33 | ppl 207.12\n", "| epoch 2 | 1400/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.34 | ppl 208.70\n", "| epoch 2 | 1600/13484 batches | lr 4.75 | ms/batch 115.14 | loss 5.36 | ppl 212.80\n", "| epoch 2 | 1800/13484 batches | lr 4.75 | ms/batch 115.11 | loss 5.35 | ppl 209.96\n", "| epoch 2 | 2000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.32 | ppl 203.54\n", "| epoch 2 | 2200/13484 batches | lr 4.75 | ms/batch 115.15 | loss 5.33 | ppl 205.82\n", "| epoch 2 | 2400/13484 batches | lr 4.75 | ms/batch 115.23 | loss 5.34 | ppl 208.95\n", "| epoch 2 | 2600/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.29 | ppl 199.16\n", "| epoch 2 | 2800/13484 batches | lr 4.75 | ms/batch 115.16 | loss 5.34 | ppl 208.19\n", "| epoch 2 | 3000/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.33 | ppl 205.88\n", "| epoch 2 | 3200/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.29 | ppl 198.11\n", "| epoch 2 | 
3400/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.37 | ppl 214.29\n", "| epoch 2 | 3600/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.31 | ppl 202.72\n", "| epoch 2 | 3800/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.32 | ppl 203.84\n", "| epoch 2 | 4000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.24 | ppl 189.14\n", "| epoch 2 | 4200/13484 batches | lr 4.75 | ms/batch 115.15 | loss 5.28 | ppl 196.95\n", "| epoch 2 | 4400/13484 batches | lr 4.75 | ms/batch 115.17 | loss 5.29 | ppl 198.84\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "| epoch 2 | 4600/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.35 | ppl 210.15\n", "| epoch 2 | 4800/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.32 | ppl 204.37\n", "| epoch 2 | 5000/13484 batches | lr 4.75 | ms/batch 115.29 | loss 5.33 | ppl 205.42\n", "| epoch 2 | 5200/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.31 | ppl 201.44\n", "| epoch 2 | 5400/13484 batches | lr 4.75 | ms/batch 115.23 | loss 5.30 | ppl 200.48\n", "| epoch 2 | 5600/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.29 | ppl 197.76\n", "| epoch 2 | 5800/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.34 | ppl 207.65\n", "| epoch 2 | 6000/13484 batches | lr 4.75 | ms/batch 115.11 | loss 5.32 | ppl 204.89\n", "| epoch 2 | 6200/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.34 | ppl 209.24\n", "| epoch 2 | 6400/13484 batches | lr 4.75 | ms/batch 115.14 | loss 5.31 | ppl 201.48\n", "| epoch 2 | 6600/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.36 | ppl 212.87\n", "| epoch 2 | 6800/13484 batches | lr 4.75 | ms/batch 115.13 | loss 5.29 | ppl 198.41\n", "| epoch 2 | 7000/13484 batches | lr 4.75 | ms/batch 115.16 | loss 5.35 | ppl 211.39\n", "| epoch 2 | 7200/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.30 | ppl 199.94\n", "| epoch 2 | 7400/13484 batches | lr 4.75 | ms/batch 115.11 | loss 5.30 | ppl 200.81\n", "| epoch 2 | 7600/13484 batches | lr 4.75 | ms/batch 115.31 | loss 5.35 | ppl 211.20\n", "| epoch 2 | 7800/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.31 | ppl 201.93\n", "| epoch 2 | 8000/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.29 | ppl 198.24\n", "| epoch 2 | 8200/13484 batches | lr 4.75 | ms/batch 115.14 | loss 5.27 | ppl 194.75\n", "| epoch 2 | 8400/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.29 | ppl 198.48\n", "| epoch 2 | 8600/13484 batches | lr 4.75 | ms/batch 115.13 | loss 5.29 | ppl 198.11\n", "| epoch 2 | 8800/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.34 | ppl 207.62\n", "| epoch 2 | 9000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.33 | ppl 205.55\n", "| epoch 2 | 9200/13484 batches | lr 4.75 | ms/batch 115.27 | loss 5.33 | ppl 206.24\n", "| epoch 2 | 9400/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.31 | ppl 201.81\n", "| epoch 2 | 9600/13484 batches | lr 4.75 | ms/batch 115.25 | loss 5.29 | ppl 198.63\n", "| epoch 2 | 9800/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.26 | ppl 192.87\n", "| epoch 2 | 10000/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.30 | ppl 199.77\n", "| epoch 2 | 10200/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.25 | ppl 191.30\n", "| epoch 2 | 10400/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.22 | ppl 184.78\n", "| epoch 2 | 10600/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.27 | ppl 194.07\n", "| epoch 2 | 10800/13484 batches | lr 4.75 | ms/batch 115.23 | loss 5.30 | ppl 200.53\n", "| epoch 2 | 11000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.29 | ppl 198.68\n", "| epoch 2 | 11200/13484 
batches | lr 4.75 | ms/batch 115.21 | loss 5.28 | ppl 196.43\n", "| epoch 2 | 11400/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.23 | ppl 186.61\n", "| epoch 2 | 11600/13484 batches | lr 4.75 | ms/batch 115.13 | loss 5.27 | ppl 195.11\n", "| epoch 2 | 11800/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.23 | ppl 186.19\n", "| epoch 2 | 12000/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.31 | ppl 202.19\n", "| epoch 2 | 12200/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.22 | ppl 184.46\n", "| epoch 2 | 12400/13484 batches | lr 4.75 | ms/batch 115.32 | loss 5.23 | ppl 187.26\n", "| epoch 2 | 12600/13484 batches | lr 4.75 | ms/batch 115.31 | loss 5.25 | ppl 189.65\n", "| epoch 2 | 12800/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.28 | ppl 196.25\n", "| epoch 2 | 13000/13484 batches | lr 4.75 | ms/batch 115.35 | loss 5.28 | ppl 196.31\n", "| epoch 2 | 13200/13484 batches | lr 4.75 | ms/batch 115.32 | loss 5.28 | ppl 195.61\n", "| epoch 2 | 13400/13484 batches | lr 4.75 | ms/batch 115.27 | loss 5.28 | ppl 195.80\n", "-----------------------------------------------------------------------------------------\n", "| end of epoch 2 | time: 1625.71s | valid loss 5.24 | valid ppl 188.48\n", "-----------------------------------------------------------------------------------------\n", "| epoch 3 | 200/13484 batches | lr 4.51 | ms/batch 115.84 | loss 5.32 | ppl 205.41\n", "| epoch 3 | 400/13484 batches | lr 4.51 | ms/batch 115.17 | loss 5.28 | ppl 195.56\n", "| epoch 3 | 600/13484 batches | lr 4.51 | ms/batch 115.12 | loss 5.22 | ppl 184.23\n", "| epoch 3 | 800/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.23 | ppl 187.41\n", "| epoch 3 | 1000/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.23 | ppl 186.77\n", "| epoch 3 | 1200/13484 batches | lr 4.51 | ms/batch 115.14 | loss 5.22 | ppl 184.68\n", "| epoch 3 | 1400/13484 batches | lr 4.51 | ms/batch 115.18 | loss 5.20 | ppl 181.17\n", "| epoch 3 | 1600/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.25 | ppl 191.20\n", "| epoch 3 | 1800/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.23 | ppl 186.87\n", "| epoch 3 | 2000/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.19 | ppl 180.16\n", "| epoch 3 | 2200/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.21 | ppl 183.82\n", "| epoch 3 | 2400/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.21 | ppl 182.76\n", "| epoch 3 | 2600/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.19 | ppl 180.25\n", "| epoch 3 | 2800/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.22 | ppl 185.75\n", "| epoch 3 | 3000/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.21 | ppl 183.06\n", "| epoch 3 | 3200/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.17 | ppl 176.28\n", "| epoch 3 | 3400/13484 batches | lr 4.51 | ms/batch 115.16 | loss 5.24 | ppl 187.88\n", "| epoch 3 | 3600/13484 batches | lr 4.51 | ms/batch 115.16 | loss 5.21 | ppl 182.87\n", "| epoch 3 | 3800/13484 batches | lr 4.51 | ms/batch 115.18 | loss 5.21 | ppl 182.52\n", "| epoch 3 | 4000/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.15 | ppl 172.43\n", "| epoch 3 | 4200/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.18 | ppl 177.72\n", "| epoch 3 | 4400/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.19 | ppl 179.22\n", "| epoch 3 | 4600/13484 batches | lr 4.51 | ms/batch 115.26 | loss 5.24 | ppl 187.99\n", "| epoch 3 | 4800/13484 batches | lr 4.51 | ms/batch 115.17 | loss 5.24 | ppl 188.20\n", "| epoch 3 | 5000/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.22 | ppl 184.24\n", "| epoch 3 | 
5200/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.21 | ppl 184.01\n", "| epoch 3 | 5400/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.19 | ppl 178.94\n", "| epoch 3 | 5600/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.19 | ppl 180.15\n", "| epoch 3 | 5800/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.20 | ppl 181.24\n", "| epoch 3 | 6000/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.22 | ppl 184.08\n", "| epoch 3 | 6200/13484 batches | lr 4.51 | ms/batch 115.29 | loss 5.24 | ppl 187.77\n", "| epoch 3 | 6400/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.21 | ppl 182.36\n", "| epoch 3 | 6600/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.25 | ppl 190.52\n", "| epoch 3 | 6800/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.20 | ppl 180.56\n", "| epoch 3 | 7000/13484 batches | lr 4.51 | ms/batch 115.14 | loss 5.23 | ppl 186.73\n", "| epoch 3 | 7200/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.19 | ppl 179.90\n", "| epoch 3 | 7400/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.21 | ppl 182.43\n", "| epoch 3 | 7600/13484 batches | lr 4.51 | ms/batch 115.09 | loss 5.20 | ppl 181.48\n", "| epoch 3 | 7800/13484 batches | lr 4.51 | ms/batch 115.26 | loss 5.22 | ppl 185.25\n", "| epoch 3 | 8000/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.18 | ppl 178.05\n", "| epoch 3 | 8200/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.18 | ppl 178.41\n", "| epoch 3 | 8400/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.20 | ppl 181.07\n", "| epoch 3 | 8600/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.20 | ppl 182.10\n", "| epoch 3 | 8800/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.18 | ppl 177.86\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "| epoch 3 | 9000/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.21 | ppl 182.88\n", "| epoch 3 | 9200/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.21 | ppl 183.47\n", "| epoch 3 | 9400/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.22 | ppl 184.25\n", "| epoch 3 | 9600/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.18 | ppl 177.24\n", "| epoch 3 | 9800/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.16 | ppl 174.24\n", "| epoch 3 | 10000/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.21 | ppl 182.29\n", "| epoch 3 | 10200/13484 batches | lr 4.51 | ms/batch 115.13 | loss 5.17 | ppl 175.34\n", "| epoch 3 | 10400/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.14 | ppl 170.79\n", "| epoch 3 | 10600/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.17 | ppl 176.55\n", "| epoch 3 | 10800/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.22 | ppl 185.77\n", "| epoch 3 | 11000/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.19 | ppl 179.38\n", "| epoch 3 | 11200/13484 batches | lr 4.51 | ms/batch 115.30 | loss 5.19 | ppl 179.59\n", "| epoch 3 | 11400/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.14 | ppl 171.38\n", "| epoch 3 | 11600/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.18 | ppl 178.51\n", "| epoch 3 | 11800/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.16 | ppl 174.52\n", "| epoch 3 | 12000/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.20 | ppl 181.28\n", "| epoch 3 | 12200/13484 batches | lr 4.51 | ms/batch 115.13 | loss 5.14 | ppl 170.89\n", "| epoch 3 | 12400/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.14 | ppl 169.88\n", "| epoch 3 | 12600/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.15 | ppl 172.67\n", "| epoch 3 | 12800/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.18 | ppl 176.89\n", "| epoch 3 | 
13000/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.19 | ppl 179.90\n", "| epoch 3 | 13200/13484 batches | lr 4.51 | ms/batch 115.29 | loss 5.17 | ppl 175.90\n", "| epoch 3 | 13400/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.20 | ppl 182.07\n", "-----------------------------------------------------------------------------------------\n", "| end of epoch 3 | time: 1625.86s | valid loss 5.19 | valid ppl 178.94\n", "-----------------------------------------------------------------------------------------\n", "| epoch 4 | 200/13484 batches | lr 4.29 | ms/batch 115.82 | loss 5.22 | ppl 184.03\n", "| epoch 4 | 400/13484 batches | lr 4.29 | ms/batch 115.24 | loss 5.16 | ppl 174.84\n", "| epoch 4 | 600/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.14 | ppl 170.18\n", "| epoch 4 | 800/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.15 | ppl 171.64\n", "| epoch 4 | 1000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.14 | ppl 171.26\n", "| epoch 4 | 1200/13484 batches | lr 4.29 | ms/batch 115.14 | loss 5.14 | ppl 171.18\n", "| epoch 4 | 1400/13484 batches | lr 4.29 | ms/batch 115.13 | loss 5.12 | ppl 166.55\n", "| epoch 4 | 1600/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.17 | ppl 176.35\n", "| epoch 4 | 1800/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.15 | ppl 172.34\n", "| epoch 4 | 2000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.13 | ppl 169.46\n", "| epoch 4 | 2200/13484 batches | lr 4.29 | ms/batch 115.23 | loss 5.16 | ppl 173.74\n", "| epoch 4 | 2400/13484 batches | lr 4.29 | ms/batch 115.14 | loss 5.14 | ppl 170.76\n", "| epoch 4 | 2600/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.11 | ppl 165.36\n", "| epoch 4 | 2800/13484 batches | lr 4.29 | ms/batch 115.14 | loss 5.15 | ppl 173.15\n", "| epoch 4 | 3000/13484 batches | lr 4.29 | ms/batch 115.15 | loss 5.14 | ppl 171.39\n", "| epoch 4 | 3200/13484 batches | lr 4.29 | ms/batch 115.27 | loss 5.10 | ppl 164.27\n", "| epoch 4 | 3400/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.16 | ppl 174.64\n", "| epoch 4 | 3600/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.13 | ppl 168.98\n", "| epoch 4 | 3800/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.12 | ppl 167.42\n", "| epoch 4 | 4000/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.08 | ppl 161.02\n", "| epoch 4 | 4200/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.11 | ppl 165.33\n", "| epoch 4 | 4400/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.11 | ppl 165.61\n", "| epoch 4 | 4600/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.16 | ppl 173.90\n", "| epoch 4 | 4800/13484 batches | lr 4.29 | ms/batch 115.27 | loss 5.15 | ppl 172.81\n", "| epoch 4 | 5000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.14 | ppl 169.98\n", "| epoch 4 | 5200/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.13 | ppl 168.94\n", "| epoch 4 | 5400/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.10 | ppl 164.28\n", "| epoch 4 | 5600/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.12 | ppl 167.23\n", "| epoch 4 | 5800/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.12 | ppl 167.63\n", "| epoch 4 | 6000/13484 batches | lr 4.29 | ms/batch 115.34 | loss 5.14 | ppl 170.26\n", "| epoch 4 | 6200/13484 batches | lr 4.29 | ms/batch 115.31 | loss 5.18 | ppl 177.13\n", "| epoch 4 | 6400/13484 batches | lr 4.29 | ms/batch 115.27 | loss 5.13 | ppl 169.45\n", "| epoch 4 | 6600/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.16 | ppl 174.83\n", "| epoch 4 | 6800/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.11 | ppl 165.20\n", "| epoch 
4 | 7000/13484 batches | lr 4.29 | ms/batch 115.19 | loss 5.16 | ppl 174.72\n", "| epoch 4 | 7200/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.12 | ppl 167.83\n", "| epoch 4 | 7400/13484 batches | lr 4.29 | ms/batch 115.29 | loss 5.12 | ppl 167.13\n", "| epoch 4 | 7600/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.13 | ppl 168.29\n", "| epoch 4 | 7800/13484 batches | lr 4.29 | ms/batch 115.30 | loss 5.12 | ppl 167.88\n", "| epoch 4 | 8000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.11 | ppl 165.65\n", "| epoch 4 | 8200/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.10 | ppl 164.16\n", "| epoch 4 | 8400/13484 batches | lr 4.29 | ms/batch 115.29 | loss 5.12 | ppl 166.71\n", "| epoch 4 | 8600/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.14 | ppl 169.91\n", "| epoch 4 | 8800/13484 batches | lr 4.29 | ms/batch 115.30 | loss 5.11 | ppl 166.00\n", "| epoch 4 | 9000/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.13 | ppl 169.67\n", "| epoch 4 | 9200/13484 batches | lr 4.29 | ms/batch 115.31 | loss 5.13 | ppl 169.46\n", "| epoch 4 | 9400/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.15 | ppl 171.85\n", "| epoch 4 | 9600/13484 batches | lr 4.29 | ms/batch 115.29 | loss 5.11 | ppl 165.01\n", "| epoch 4 | 9800/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.09 | ppl 162.51\n", "| epoch 4 | 10000/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.12 | ppl 167.81\n", "| epoch 4 | 10200/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.10 | ppl 163.43\n", "| epoch 4 | 10400/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.07 | ppl 158.79\n", "| epoch 4 | 10600/13484 batches | lr 4.29 | ms/batch 115.30 | loss 5.10 | ppl 163.54\n", "| epoch 4 | 10800/13484 batches | lr 4.29 | ms/batch 115.39 | loss 5.12 | ppl 167.03\n", "| epoch 4 | 11000/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.11 | ppl 166.04\n", "| epoch 4 | 11200/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.11 | ppl 165.67\n", "| epoch 4 | 11400/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.06 | ppl 157.42\n", "| epoch 4 | 11600/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.10 | ppl 164.17\n", "| epoch 4 | 11800/13484 batches | lr 4.29 | ms/batch 115.36 | loss 5.07 | ppl 159.41\n", "| epoch 4 | 12000/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.13 | ppl 168.33\n", "| epoch 4 | 12200/13484 batches | lr 4.29 | ms/batch 115.24 | loss 5.05 | ppl 155.52\n", "| epoch 4 | 12400/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.07 | ppl 159.62\n", "| epoch 4 | 12600/13484 batches | lr 4.29 | ms/batch 115.24 | loss 5.09 | ppl 161.65\n", "| epoch 4 | 12800/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.10 | ppl 164.49\n", "| epoch 4 | 13000/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.10 | ppl 163.47\n", "| epoch 4 | 13200/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.09 | ppl 162.89\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "| epoch 4 | 13400/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.12 | ppl 166.66\n", "-----------------------------------------------------------------------------------------\n", "| end of epoch 4 | time: 1626.36s | valid loss 5.13 | valid ppl 168.54\n", "-----------------------------------------------------------------------------------------\n", "| epoch 5 | 200/13484 batches | lr 4.07 | ms/batch 115.90 | loss 5.14 | ppl 170.65\n", "| epoch 5 | 400/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.09 | ppl 163.18\n", "| epoch 5 | 600/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.07 | ppl 159.22\n", "| epoch 5 | 
800/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.08 | ppl 160.60\n", "| epoch 5 | 1000/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.08 | ppl 160.49\n", "| epoch 5 | 1200/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.07 | ppl 158.86\n", "| epoch 5 | 1400/13484 batches | lr 4.07 | ms/batch 115.14 | loss 5.06 | ppl 156.88\n", "| epoch 5 | 1600/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.10 | ppl 164.68\n", "| epoch 5 | 1800/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.09 | ppl 161.68\n", "| epoch 5 | 2000/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.05 | ppl 156.19\n", "| epoch 5 | 2200/13484 batches | lr 4.07 | ms/batch 115.16 | loss 5.06 | ppl 157.65\n", "| epoch 5 | 2400/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.05 | ppl 156.29\n", "| epoch 5 | 2600/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.04 | ppl 155.08\n", "| epoch 5 | 2800/13484 batches | lr 4.07 | ms/batch 115.12 | loss 5.08 | ppl 160.79\n", "| epoch 5 | 3000/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.06 | ppl 157.93\n", "| epoch 5 | 3200/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.03 | ppl 153.59\n", "| epoch 5 | 3400/13484 batches | lr 4.07 | ms/batch 115.24 | loss 5.10 | ppl 164.69\n", "| epoch 5 | 3600/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.07 | ppl 159.67\n", "| epoch 5 | 3800/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.05 | ppl 156.33\n", "| epoch 5 | 4000/13484 batches | lr 4.07 | ms/batch 115.30 | loss 5.00 | ppl 148.52\n", "| epoch 5 | 4200/13484 batches | lr 4.07 | ms/batch 115.16 | loss 5.03 | ppl 153.04\n", "| epoch 5 | 4400/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.04 | ppl 155.12\n", "| epoch 5 | 4600/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.09 | ppl 162.86\n", "| epoch 5 | 4800/13484 batches | lr 4.07 | ms/batch 115.21 | loss 5.07 | ppl 159.17\n", "| epoch 5 | 5000/13484 batches | lr 4.07 | ms/batch 115.27 | loss 5.06 | ppl 157.50\n", "| epoch 5 | 5200/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.06 | ppl 157.70\n", "| epoch 5 | 5400/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.04 | ppl 154.31\n", "| epoch 5 | 5600/13484 batches | lr 4.07 | ms/batch 115.18 | loss 5.05 | ppl 156.47\n", "| epoch 5 | 5800/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.06 | ppl 157.43\n", "| epoch 5 | 6000/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.07 | ppl 159.33\n", "| epoch 5 | 6200/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.09 | ppl 163.19\n", "| epoch 5 | 6400/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.07 | ppl 159.77\n", "| epoch 5 | 6600/13484 batches | lr 4.07 | ms/batch 115.28 | loss 5.09 | ppl 163.17\n", "| epoch 5 | 6800/13484 batches | lr 4.07 | ms/batch 115.11 | loss 5.03 | ppl 153.03\n", "| epoch 5 | 7000/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.09 | ppl 161.90\n", "| epoch 5 | 7200/13484 batches | lr 4.07 | ms/batch 115.21 | loss 5.06 | ppl 156.90\n", "| epoch 5 | 7400/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.07 | ppl 159.02\n", "| epoch 5 | 7600/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.05 | ppl 156.02\n", "| epoch 5 | 7800/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.06 | ppl 157.20\n", "| epoch 5 | 8000/13484 batches | lr 4.07 | ms/batch 115.20 | loss 5.04 | ppl 154.56\n", "| epoch 5 | 8200/13484 batches | lr 4.07 | ms/batch 115.20 | loss 5.03 | ppl 152.46\n", "| epoch 5 | 8400/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.06 | ppl 157.62\n", "| epoch 5 | 8600/13484 batches | lr 4.07 | ms/batch 115.28 | loss 5.07 | ppl 158.74\n", "| epoch 
5 | 8800/13484 batches | lr 4.07 | ms/batch 115.30 | loss 5.04 | ppl 154.53\n", "| epoch 5 | 9000/13484 batches | lr 4.07 | ms/batch 115.31 | loss 5.06 | ppl 157.02\n", "| epoch 5 | 9200/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.07 | ppl 159.14\n", "| epoch 5 | 9400/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.07 | ppl 159.15\n", "| epoch 5 | 9600/13484 batches | lr 4.07 | ms/batch 115.24 | loss 5.04 | ppl 153.89\n", "| epoch 5 | 9800/13484 batches | lr 4.07 | ms/batch 115.27 | loss 5.02 | ppl 151.96\n", "| epoch 5 | 10000/13484 batches | lr 4.07 | ms/batch 115.24 | loss 5.05 | ppl 156.58\n", "| epoch 5 | 10200/13484 batches | lr 4.07 | ms/batch 115.30 | loss 5.02 | ppl 152.10\n", "| epoch 5 | 10400/13484 batches | lr 4.07 | ms/batch 115.33 | loss 4.98 | ppl 146.11\n", "| epoch 5 | 10600/13484 batches | lr 4.07 | ms/batch 115.26 | loss 5.03 | ppl 153.27\n", "| epoch 5 | 10800/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.06 | ppl 157.44\n", "| epoch 5 | 11000/13484 batches | lr 4.07 | ms/batch 115.33 | loss 5.05 | ppl 156.34\n", "| epoch 5 | 11200/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.04 | ppl 154.36\n", "| epoch 5 | 11400/13484 batches | lr 4.07 | ms/batch 115.27 | loss 5.00 | ppl 148.51\n", "| epoch 5 | 11600/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.03 | ppl 153.09\n", "| epoch 5 | 11800/13484 batches | lr 4.07 | ms/batch 115.26 | loss 5.00 | ppl 148.85\n", "| epoch 5 | 12000/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.06 | ppl 156.93\n", "| epoch 5 | 12200/13484 batches | lr 4.07 | ms/batch 115.22 | loss 4.98 | ppl 145.89\n", "| epoch 5 | 12400/13484 batches | lr 4.07 | ms/batch 115.20 | loss 5.00 | ppl 148.86\n", "| epoch 5 | 12600/13484 batches | lr 4.07 | ms/batch 115.33 | loss 5.02 | ppl 151.32\n", "| epoch 5 | 12800/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.04 | ppl 154.42\n", "| epoch 5 | 13000/13484 batches | lr 4.07 | ms/batch 115.28 | loss 5.03 | ppl 152.95\n", "| epoch 5 | 13200/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.03 | ppl 153.49\n", "| epoch 5 | 13400/13484 batches | lr 4.07 | ms/batch 115.35 | loss 5.05 | ppl 155.92\n", "-----------------------------------------------------------------------------------------\n", "| end of epoch 5 | time: 1625.93s | valid loss 5.10 | valid ppl 164.06\n", "-----------------------------------------------------------------------------------------\n", "| epoch 6 | 200/13484 batches | lr 3.87 | ms/batch 115.82 | loss 5.07 | ppl 159.79\n", "| epoch 6 | 400/13484 batches | lr 3.87 | ms/batch 115.22 | loss 5.04 | ppl 153.75\n", "| epoch 6 | 600/13484 batches | lr 3.87 | ms/batch 115.14 | loss 5.01 | ppl 150.20\n", "| epoch 6 | 800/13484 batches | lr 3.87 | ms/batch 115.25 | loss 5.01 | ppl 150.24\n", "| epoch 6 | 1000/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.01 | ppl 149.42\n", "| epoch 6 | 1200/13484 batches | lr 3.87 | ms/batch 115.09 | loss 5.01 | ppl 150.28\n", "| epoch 6 | 1400/13484 batches | lr 3.87 | ms/batch 115.18 | loss 5.00 | ppl 148.53\n", "| epoch 6 | 1600/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.05 | ppl 156.45\n", "| epoch 6 | 1800/13484 batches | lr 3.87 | ms/batch 115.17 | loss 5.02 | ppl 151.97\n", "| epoch 6 | 2000/13484 batches | lr 3.87 | ms/batch 115.14 | loss 5.00 | ppl 147.68\n", "| epoch 6 | 2200/13484 batches | lr 3.87 | ms/batch 115.22 | loss 5.00 | ppl 148.99\n", "| epoch 6 | 2400/13484 batches | lr 3.87 | ms/batch 115.19 | loss 5.00 | ppl 147.82\n", "| epoch 6 | 2600/13484 batches | lr 3.87 | ms/batch 115.19 | loss 4.98 | ppl 
145.20\n", "| epoch 6 | 2800/13484 batches | lr 3.87 | ms/batch 115.20 | loss 5.02 | ppl 152.00\n", "| epoch 6 | 3000/13484 batches | lr 3.87 | ms/batch 115.20 | loss 5.01 | ppl 149.24\n", "| epoch 6 | 3200/13484 batches | lr 3.87 | ms/batch 115.23 | loss 4.98 | ppl 145.09\n", "| epoch 6 | 3400/13484 batches | lr 3.87 | ms/batch 115.35 | loss 5.03 | ppl 153.68\n", "| epoch 6 | 3600/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.01 | ppl 149.34\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "| epoch 6 | 3800/13484 batches | lr 3.87 | ms/batch 115.20 | loss 5.00 | ppl 148.07\n", "| epoch 6 | 4000/13484 batches | lr 3.87 | ms/batch 115.32 | loss 4.94 | ppl 140.04\n", "| epoch 6 | 4200/13484 batches | lr 3.87 | ms/batch 115.21 | loss 4.97 | ppl 144.64\n", "| epoch 6 | 4400/13484 batches | lr 3.87 | ms/batch 115.20 | loss 4.99 | ppl 146.48\n", "| epoch 6 | 4600/13484 batches | lr 3.87 | ms/batch 115.18 | loss 5.03 | ppl 153.49\n", "| epoch 6 | 4800/13484 batches | lr 3.87 | ms/batch 115.30 | loss 5.01 | ppl 150.20\n", "| epoch 6 | 5000/13484 batches | lr 3.87 | ms/batch 115.24 | loss 5.00 | ppl 148.23\n", "| epoch 6 | 5200/13484 batches | lr 3.87 | ms/batch 115.22 | loss 5.00 | ppl 148.51\n", "| epoch 6 | 5400/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.98 | ppl 145.45\n", "| epoch 6 | 5600/13484 batches | lr 3.87 | ms/batch 115.25 | loss 4.99 | ppl 146.84\n", "| epoch 6 | 5800/13484 batches | lr 3.87 | ms/batch 115.29 | loss 4.99 | ppl 147.24\n", "| epoch 6 | 6000/13484 batches | lr 3.87 | ms/batch 115.19 | loss 5.01 | ppl 150.09\n", "| epoch 6 | 6200/13484 batches | lr 3.87 | ms/batch 115.21 | loss 5.03 | ppl 152.86\n", "| epoch 6 | 6400/13484 batches | lr 3.87 | ms/batch 115.17 | loss 5.02 | ppl 150.83\n", "| epoch 6 | 6600/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.05 | ppl 155.25\n", "| epoch 6 | 6800/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.00 | ppl 148.10\n", "| epoch 6 | 7000/13484 batches | lr 3.87 | ms/batch 115.35 | loss 5.03 | ppl 152.52\n", "| epoch 6 | 7200/13484 batches | lr 3.87 | ms/batch 115.25 | loss 5.00 | ppl 148.62\n", "| epoch 6 | 7400/13484 batches | lr 3.87 | ms/batch 115.30 | loss 5.00 | ppl 148.56\n", "| epoch 6 | 7600/13484 batches | lr 3.87 | ms/batch 115.25 | loss 4.99 | ppl 147.28\n", "| epoch 6 | 7800/13484 batches | lr 3.87 | ms/batch 115.24 | loss 5.00 | ppl 147.93\n", "| epoch 6 | 8000/13484 batches | lr 3.87 | ms/batch 115.24 | loss 4.98 | ppl 145.76\n", "| epoch 6 | 8200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.97 | ppl 143.39\n", "| epoch 6 | 8400/13484 batches | lr 3.87 | ms/batch 115.24 | loss 4.99 | ppl 147.14\n", "| epoch 6 | 8600/13484 batches | lr 3.87 | ms/batch 115.27 | loss 5.00 | ppl 148.00\n", "| epoch 6 | 8800/13484 batches | lr 3.87 | ms/batch 115.35 | loss 4.98 | ppl 145.27\n", "| epoch 6 | 9000/13484 batches | lr 3.87 | ms/batch 115.27 | loss 5.01 | ppl 150.06\n", "| epoch 6 | 9200/13484 batches | lr 3.87 | ms/batch 115.21 | loss 5.01 | ppl 150.09\n", "| epoch 6 | 9400/13484 batches | lr 3.87 | ms/batch 115.28 | loss 5.01 | ppl 150.08\n", "| epoch 6 | 9600/13484 batches | lr 3.87 | ms/batch 115.16 | loss 4.99 | ppl 147.55\n", "| epoch 6 | 9800/13484 batches | lr 3.87 | ms/batch 115.27 | loss 4.97 | ppl 143.67\n", "| epoch 6 | 10000/13484 batches | lr 3.87 | ms/batch 115.20 | loss 4.99 | ppl 147.66\n", "| epoch 6 | 10200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.95 | ppl 141.61\n", "| epoch 6 | 10400/13484 batches | lr 3.87 | ms/batch 115.20 | loss 4.93 | ppl 138.76\n", "| epoch 6 
| 10600/13484 batches | lr 3.87 | ms/batch 115.28 | loss 4.97 | ppl 144.59\n",
"[... per-batch log lines for the remainder of training omitted; end-of-epoch summaries follow ...]\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch   6 | time: 1626.41s | valid loss  5.09 | valid ppl   162.11\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch   7 | time: 1626.06s | valid loss  5.05 | valid ppl   155.94\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch   8 | time: 1626.66s | valid loss  5.00 | valid ppl   148.39\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch   9 | time: 1628.05s | valid loss  5.02 | valid ppl   150.80\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch  10 | time: 1628.47s | valid loss  4.98 | valid ppl   145.33\n",
"-----------------------------------------------------------------------------------------\n" ] } ],
"source": [
"best_val_loss = float('inf')\n",
"epochs = 10\n",
"best_model = None\n",
"\n",
"for epoch in range(1, epochs + 1):\n",
"    epoch_start_time = time.time()\n",
"    train(model)\n",
"    val_loss = evaluate(model, val_data)\n",
"    val_ppl = math.exp(val_loss)\n",
"    elapsed = time.time() - epoch_start_time\n",
"    print('-' * 89)\n",
"    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '\n",
"          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')\n",
"    print('-' * 89)\n",
"\n",
"    # keep a copy of the model with the lowest validation loss seen so far\n",
"    if val_loss < best_val_loss:\n",
"        best_val_loss = val_loss\n",
"        best_model = copy.deepcopy(model)\n",
"\n",
"    # decay the learning rate once per epoch\n",
"    scheduler.step()" ] },
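{ "cell_type": "markdown", "id": "pplnote01", "metadata": {}, "source": [
"### note on the ppl column\n",
"The logged perplexity carries no extra information: it is simply the exponential of the cross-entropy loss. A quick sanity check against the end-of-epoch summaries above (values match up to rounding of the logged loss):" ] },
{ "cell_type": "code", "execution_count": null, "id": "pplnote02", "metadata": {}, "outputs": [], "source": [
"import math\n",
"\n",
"# valid losses taken from the end-of-epoch summaries above\n",
"for epoch, val_loss in [(6, 5.09), (7, 5.05), (8, 5.00), (9, 5.02), (10, 4.98)]:\n",
"    print(f\"epoch {epoch}: valid loss {val_loss:.2f} -> ppl {math.exp(val_loss):.2f}\")" ] },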
{ "cell_type": "markdown", "id": "f0d32419", "metadata": {}, "source": [ "### print test results for the best model after training" ] },
{ "cell_type": "code", "execution_count": 32, "id": "12fdd0aa", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"=========================================================================================\n",
"| End of training | test loss  4.98 | test ppl   144.89\n",
"=========================================================================================\n" ] } ],
"source": [
"test_loss = evaluate(best_model, test_data)\n",
"test_ppl = math.exp(test_loss)\n",
"print('=' * 89)\n",
"print(f'| End of training | test loss {test_loss:5.2f} | '\n",
"      f'test ppl {test_ppl:8.2f}')\n",
"print('=' * 89)" ] },
{ "cell_type": "markdown", "id": "528c9f10", "metadata": {}, "source": [ "### save the trained model to a file" ] },
{ "cell_type": "code", "execution_count": 33, "id": "848af399", "metadata": {}, "outputs": [], "source": [ "torch.save(best_model, \"pubmed-sentencecomplete.pt\")" ] },
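{ "cell_type": "markdown", "id": "ckptnote1", "metadata": {}, "source": [
"### note on the checkpoint format\n",
"`torch.save(best_model, ...)` pickles the entire module, so loading it again requires the exact same class definition on the import path. A more portable alternative is to save only the parameters; the sketch below assumes a `TransformerModel` is re-created with the same hyperparameters as during training (the state-dict filename is illustrative):" ] },
{ "cell_type": "code", "execution_count": null, "id": "ckptnote2", "metadata": {}, "outputs": [], "source": [
"# sketch: save only the learned parameters instead of the pickled module\n",
"torch.save(best_model.state_dict(), \"pubmed-sentencecomplete-statedict.pt\")\n",
"\n",
"# to restore, rebuild the model first (constructor arguments must match training),\n",
"# then load the parameters into it:\n",
"# model = TransformerModel(...)  # same hyperparameters as used above\n",
"# model.load_state_dict(torch.load(\"pubmed-sentencecomplete-statedict.pt\"))\n",
"# model.to(device).eval()" ] },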
"source": [ "### save trained model to file" ] }, { "cell_type": "code", "execution_count": 33, "id": "848af399", "metadata": {}, "outputs": [], "source": [ "torch.save(best_model, \"pubmed-sentencecomplete.pt\")" ] }, { "cell_type": "markdown", "id": "09df56cf", "metadata": {}, "source": [ "## Now we can try to predict based on trained model" ] }, { "cell_type": "markdown", "id": "fe250072", "metadata": {}, "source": [ "### obtain iterator for predict batch " ] }, { "cell_type": "code", "execution_count": 159, "id": "afe585d6", "metadata": {}, "outputs": [], "source": [ "def predict_abstract_iter(batch):\n", " for batch in batch:\n", " yield tokenizer(batch)" ] }, { "cell_type": "markdown", "id": "b043de0a", "metadata": {}, "source": [ "### load data into tensor for model to process" ] }, { "cell_type": "code", "execution_count": 154, "id": "8bfaa8bd", "metadata": {}, "outputs": [], "source": [ "def toDataTensor(batch):\n", " predict_generator = predict_abstract_iter(batch)\n", " return [torch.tensor(vocab.lookup_indices(item)) for item in predict_generator]" ] }, { "cell_type": "markdown", "id": "a800ffea", "metadata": {}, "source": [ "### check device once again (prob not needed)" ] }, { "cell_type": "code", "execution_count": 7, "id": "6e2c35ba", "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'torch' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn [7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m device \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 2\u001b[0m device\n", "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined" ] } ], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device" ] }, { "cell_type": "markdown", "id": "bef90722", "metadata": {}, "source": [ "### optionally load model from file if it was trained already" ] }, { "cell_type": "code", "execution_count": 50, "id": "223eed8a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TransformerModel(\n", " (pos_encoder): PositionalEncoding(\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " )\n", " (transformer_encoder): TransformerEncoder(\n", " (layers): ModuleList(\n", " (0): TransformerEncoderLayer(\n", " (self_attn): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)\n", " )\n", " (linear1): Linear(in_features=200, out_features=200, bias=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " (linear2): Linear(in_features=200, out_features=200, bias=True)\n", " (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n", " (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n", " (dropout1): Dropout(p=0.2, inplace=False)\n", " (dropout2): Dropout(p=0.2, inplace=False)\n", " )\n", " (1): TransformerEncoderLayer(\n", " (self_attn): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)\n", " )\n", " (linear1): Linear(in_features=200, 
{ "cell_type": "markdown", "id": "a800ffea", "metadata": {}, "source": [ "### check the device once again (probably redundant)" ] },
{ "cell_type": "code", "execution_count": 7, "id": "6e2c35ba", "metadata": {}, "outputs": [], "source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"device" ] },
{ "cell_type": "markdown", "id": "bef90722", "metadata": {}, "source": [ "### optionally load the model from file if it was trained already" ] },
{ "cell_type": "code", "execution_count": 50, "id": "223eed8a", "metadata": {}, "outputs": [ { "data": { "text/plain": [
"TransformerModel(\n",
"  (pos_encoder): PositionalEncoding(\n",
"    (dropout): Dropout(p=0.2, inplace=False)\n",
"  )\n",
"  (transformer_encoder): TransformerEncoder(\n",
"    (layers): ModuleList(\n",
"      (0): TransformerEncoderLayer(\n",
"        (self_attn): MultiheadAttention(\n",
"          (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)\n",
"        )\n",
"        (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
"        (dropout): Dropout(p=0.2, inplace=False)\n",
"        (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
"        (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
"        (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
"        (dropout1): Dropout(p=0.2, inplace=False)\n",
"        (dropout2): Dropout(p=0.2, inplace=False)\n",
"      )\n",
"      (1): TransformerEncoderLayer(\n",
"        (self_attn): MultiheadAttention(\n",
"          (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)\n",
"        )\n",
"        (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
"        (dropout): Dropout(p=0.2, inplace=False)\n",
"        (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
"        (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
"        (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
"        (dropout1): Dropout(p=0.2, inplace=False)\n",
"        (dropout2): Dropout(p=0.2, inplace=False)\n",
"      )\n",
"    )\n",
"  )\n",
"  (encoder): Embedding(163987, 200)\n",
"  (decoder): Linear(in_features=200, out_features=163987, bias=True)\n",
")" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"best_model = torch.load(\"pubmed-sentencecomplete.pt\")\n",
"best_model.eval()" ] },
{ "cell_type": "markdown", "id": "dd71bdfc", "metadata": {}, "source": [ "### define the predict function" ] },
{ "cell_type": "code", "execution_count": 160, "id": "64223e87", "metadata": {}, "outputs": [], "source": [
"def predict(input_line, mask, n_predictions=3):\n",
"    with torch.no_grad():\n",
"        output = best_model(input_line.to(device), mask)\n",
"        predictions = []\n",
"        for i in range(n_predictions):\n",
"            # take the (i+1)-th most likely token for the last position in the sequence\n",
"            next_token_index = output.topk(i + 1)[1].view(-1)[-1].item()\n",
"            predictions.append(vocab.lookup_token(next_token_index))\n",
"\n",
"        return predictions" ] },
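{ "cell_type": "markdown", "id": "masknote1", "metadata": {}, "source": [
"### make sure `generate_square_subsequent_mask` is in scope\n",
"The cells below call `generate_square_subsequent_mask`, which is defined in the training part of this notebook. When only this prediction section is run from a fresh kernel, the standard causal-mask helper from the PyTorch transformer tutorial can be used instead (assumed here to be equivalent to the definition used during training):" ] },
{ "cell_type": "code", "execution_count": null, "id": "masknote2", "metadata": {}, "outputs": [], "source": [
"def generate_square_subsequent_mask(sz):\n",
"    # -inf strictly above the diagonal, zeros elsewhere:\n",
"    # position i may only attend to itself and earlier positions (no lookahead)\n",
"    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)" ] },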
{ "cell_type": "markdown", "id": "a9b7311b", "metadata": {}, "source": [ "### define the input batch" ] },
{ "cell_type": "code", "execution_count": 2, "id": "913628b4", "metadata": {}, "outputs": [], "source": [
"sample_batch = [\n",
"    \"There is\"\n",
"]\n",
"input_batch = sample_batch" ] },
{ "cell_type": "markdown", "id": "45930a71", "metadata": {}, "source": [ "### define the initial source mask for the model" ] },
{ "cell_type": "code", "execution_count": 3, "id": "c4bba2a1", "metadata": {}, "outputs": [], "source": [
"bptt = 2\n",
"src_mask = generate_square_subsequent_mask(bptt).to(device)" ] },
{ "cell_type": "markdown", "id": "5b33b9f3", "metadata": {}, "source": [ "### Execute the prediction, display the candidate continuations, and choose one" ] },
{ "cell_type": "code", "execution_count": 4, "id": "b2895698", "metadata": {}, "outputs": [], "source": [
"def predict_loop(num_of_pred):\n",
"    iteration = 0\n",
"    is_terminated = False\n",
"    input_batch = sample_batch\n",
"    while not is_terminated:\n",
"        # grow the causal mask by one position for every accepted token\n",
"        mask_size = bptt + iteration\n",
"        src_mask = generate_square_subsequent_mask(mask_size).to(device)\n",
"        data = toDataTensor(input_batch)\n",
"\n",
"        for i, d in enumerate(data):\n",
"            predictions = predict(d, src_mask, num_of_pred)\n",
"\n",
"            print(\"\\n Possible continuations:\")\n",
"            for j in range(len(predictions)):\n",
"                print(j + 1, \": \", predictions[j])\n",
"            # pick a continuation by number, or type 'e' to stop\n",
"            s_index = input(input_batch[i])\n",
"            if \"e\" in s_index:\n",
"                is_terminated = True\n",
"                print(\"prediction stopped.\")\n",
"                break\n",
"\n",
"            print(\"Text is now:\")\n",
"            input_batch[i] += \" \" + predictions[int(s_index) - 1]\n",
"            print(input_batch[i])\n",
"\n",
"        iteration = iteration + 1" ] },
{ "cell_type": "code", "execution_count": 5, "id": "13ed9298", "metadata": {}, "outputs": [], "source": [ "predict_loop(3)" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }