  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "622dfcd6",
  6. "metadata": {},
  7. "source": [
  8. "# Group 09 - Automatic Sentence Completion for PubMed"
  9. ]
  10. },
  11. {
  12. "cell_type": "markdown",
  13. "id": "5e4cec3c",
  14. "metadata": {},
  15. "source": [
  16. "### Authors:\n",
  17. "- Constantin Fuerst\t\n",
  18. "- Leonard Starke"
  19. ]
  20. },
  21. {
  22. "cell_type": "markdown",
  23. "id": "806cfb27",
  24. "metadata": {},
  25. "source": [
  26. "### link to \"Attention is All You Need\" paper describing transformer models"
  27. ]
  28. },
  29. {
  30. "cell_type": "code",
  31. "execution_count": null,
  32. "id": "fe862072",
  33. "metadata": {},
  34. "outputs": [],
  35. "source": [
  36. "https://arxiv.org/pdf/1706.03762.pdf"
  37. ]
  38. },
  39. {
  40. "cell_type": "markdown",
  41. "id": "fa161b1b",
  42. "metadata": {},
  43. "source": [
  44. "### load query data from text file "
  45. ]
  46. },
  47. {
  48. "cell_type": "code",
  49. "execution_count": 2,
  50. "id": "e1912a79",
  51. "metadata": {
  52. "scrolled": false
  53. },
  54. "outputs": [
  55. {
  56. "name": "stdout",
  57. "output_type": "stream",
  58. "text": [
  59. "--2023-01-18 15:48:24-- https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download\n",
  60. "Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'\n",
  61. "Resolving cloud.constantin-fuerst.com (cloud.constantin-fuerst.com)... 95.91.21.14\n",
  62. "Connecting to cloud.constantin-fuerst.com (cloud.constantin-fuerst.com)|95.91.21.14|:443... connected.\n",
  63. "HTTP request sent, awaiting response... 200 OK\n",
  64. "Length: 1100551 (1.0M) [text/plain]\n",
  65. "Saving to: ‘pubmed-query.txt’\n",
  66. "\n",
  67. "pubmed-query.txt 100%[===================>] 1.05M 1.61MB/s in 0.7s \n",
  68. "\n",
  69. "2023-01-18 15:48:25 (1.61 MB/s) - ‘pubmed-query.txt’ saved [1100551/1100551]\n",
  70. "\n"
  71. ]
  72. }
  73. ],
  74. "source": [
  75. "!wget https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download -O pubmed-query.txt"
  76. ]
  77. },
  78. {
  79. "cell_type": "markdown",
  80. "id": "da068411",
  81. "metadata": {},
  82. "source": [
  83. "### import modules used for parsing query data"
  84. ]
  85. },
  86. {
  87. "cell_type": "code",
  88. "execution_count": null,
  89. "id": "c10bc5a8",
  90. "metadata": {},
  91. "outputs": [],
  92. "source": [
  93. "try:\n",
  94. " from Bio import Medline\n",
  95. "except: \n",
  96. " !pip install Bio\n",
  97. " from Bio import Medline"
  98. ]
  99. },
  100. {
  101. "cell_type": "markdown",
  102. "id": "7bf15c30",
  103. "metadata": {},
  104. "source": [
  105. "### define function for loading the papers from PubMed database"
  106. ]
  107. },
  108. {
  109. "cell_type": "code",
  110. "execution_count": 1,
  111. "id": "adfb256a",
  112. "metadata": {},
  113. "outputs": [],
  114. "source": [
  115. "def getPapers(filename):\n",
  116. " pubmed_query = open(filename, encoding='utf-8')\n",
  117. " records = Medline.parse(pubmed_query)\n",
  118. " return list(records)"
  119. ]
  120. },
  121. {
  122. "cell_type": "markdown",
  123. "id": "46bc6298",
  124. "metadata": {},
  125. "source": [
  126. "### Verify that its working"
  127. ]
  128. },
  129. {
  130. "cell_type": "code",
  131. "execution_count": 4,
  132. "id": "00481ec9",
  133. "metadata": {},
  134. "outputs": [
  135. {
  136. "name": "stdout",
  137. "output_type": "stream",
  138. "text": [
  139. "Got 150000 records from the query text file\n"
  140. ]
  141. }
  142. ],
  143. "source": [
  144. "max_records = 150000\n",
  145. "records = getPapers(\"pubmed-query.txt\")\n",
  146. "records = records[:min(max_records, len(records))]\n",
  147. "print(f\"Got {len(records)} records from the query text file\")"
  148. ]
  149. },
  150. {
  151. "cell_type": "markdown",
  152. "id": "b67747c6",
  153. "metadata": {},
  154. "source": [
  155. "### Now extract abstracts from records"
  156. ]
  157. },
  158. {
  159. "cell_type": "code",
  160. "execution_count": 5,
  161. "id": "dcf5c217",
  162. "metadata": {},
  163. "outputs": [],
  164. "source": [
  165. "r_abstracts = []\n",
  166. "for r in records:\n",
  167. " if not (r.get('AB') is None):\n",
  168. " r_abstracts.append(r['AB'])"
  169. ]
  170. },
  171. {
  172. "cell_type": "markdown",
  173. "id": "e309f6fe",
  174. "metadata": {},
  175. "source": [
  176. "### Now import torch modules needed to load the data"
  177. ]
  178. },
  179. {
  180. "cell_type": "code",
  181. "execution_count": 6,
  182. "id": "c3199444",
  183. "metadata": {},
  184. "outputs": [
  185. {
  186. "name": "stderr",
  187. "output_type": "stream",
  188. "text": [
  189. "/home/hein/.local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
  190. " from .autonotebook import tqdm as notebook_tqdm\n"
  191. ]
  192. }
  193. ],
  194. "source": [
  195. "try:\n",
  196. " import torch\n",
  197. " from torch.utils.data import Dataset \n",
  198. " from torchtext.data import get_tokenizer\n",
  199. "except:\n",
  200. " !pip --default-timeout=1000 install torch\n",
  201. " !pip --default-timeout=1000 install torchtext\n",
  202. " import torch\n",
  203. " from torch.utils.data import Dataset \n",
  204. " from torchtext.data import get_tokenizer"
  205. ]
  206. },
  207. {
  208. "cell_type": "markdown",
  209. "id": "5b4007e8",
  210. "metadata": {},
  211. "source": [
  212. "### Import numpy"
  213. ]
  214. },
  215. {
  216. "cell_type": "code",
  217. "execution_count": 7,
  218. "id": "daca9db6",
  219. "metadata": {},
  220. "outputs": [],
  221. "source": [
  222. "try:\n",
  223. " import numpy as np\n",
  224. "except:\n",
  225. " !pip install numpy\n",
  226. " import numpy as np\n"
  227. ]
  228. },
  229. {
  230. "cell_type": "markdown",
  231. "id": "683ed2fc",
  232. "metadata": {},
  233. "source": [
  234. "### import math module"
  235. ]
  236. },
  237. {
  238. "cell_type": "code",
  239. "execution_count": 8,
  240. "id": "8d2312db",
  241. "metadata": {},
  242. "outputs": [],
  243. "source": [
  244. "import math"
  245. ]
  246. },
  247. {
  248. "cell_type": "markdown",
  249. "id": "4df1e449",
  250. "metadata": {},
  251. "source": [
  252. "### define token iterators"
  253. ]
  254. },
  255. {
  256. "cell_type": "code",
  257. "execution_count": 9,
  258. "id": "3f23404d",
  259. "metadata": {},
  260. "outputs": [],
  261. "source": [
  262. "train_size = math.floor(len(r_abstracts) * 0.75)\n",
  263. "val_size = math.floor(len(r_abstracts) * 0.125)\n",
  264. "test_size = math.floor(len(r_abstracts) * 0.125)"
  265. ]
  266. },
  267. {
  268. "cell_type": "code",
  269. "execution_count": 10,
  270. "id": "8a128d3c",
  271. "metadata": {},
  272. "outputs": [],
  273. "source": [
  274. "def train_abstract_iter():\n",
  275. " for abstract in r_abstracts[:train_size]:\n",
  276. " yield abstract"
  277. ]
  278. },
  279. {
  280. "cell_type": "code",
  281. "execution_count": 11,
  282. "id": "97e89986",
  283. "metadata": {},
  284. "outputs": [],
  285. "source": [
  286. "def val_abstract_iter():\n",
  287. " for abstract in r_abstracts[(train_size + 1):(train_size + val_size)]:\n",
  288. " yield abstract"
  289. ]
  290. },
  291. {
  292. "cell_type": "code",
  293. "execution_count": 12,
  294. "id": "0d6e89c4",
  295. "metadata": {},
  296. "outputs": [],
  297. "source": [
  298. "def test_abstract_iter():\n",
  299. " for abstract in r_abstracts[(train_size + val_size + 1): (train_size + val_size + test_size)]:\n",
  300. " yield abstract"
  301. ]
  302. },
  303. {
  304. "cell_type": "markdown",
  305. "id": "e5e9c5a2",
  306. "metadata": {},
  307. "source": [
  308. "### define Tokenize function"
  309. ]
  310. },
  311. {
  312. "cell_type": "code",
  313. "execution_count": 13,
  314. "id": "0bdbc40a",
  315. "metadata": {},
  316. "outputs": [],
  317. "source": [
  318. "tokenizer = get_tokenizer(\"basic_english\")\n",
  319. "def tokenize_abstract_iter():\n",
  320. " for abstract in r_abstracts:\n",
  321. " yield tokenizer(abstract)"
  322. ]
  323. },
  324. {
  325. "cell_type": "markdown",
  326. "id": "37da40bb",
  327. "metadata": {},
  328. "source": [
  329. "### Map every word to an id to store inside torch tensor"
  330. ]
  331. },
  332. {
  333. "cell_type": "code",
  334. "execution_count": 14,
  335. "id": "a438ab1f",
  336. "metadata": {},
  337. "outputs": [],
  338. "source": [
  339. "from torchtext.vocab import build_vocab_from_iterator\n",
  340. "token_generator = tokenize_abstract_iter()\n",
  341. "vocab = build_vocab_from_iterator(token_generator, specials=['<unk>'])\n",
  342. "vocab.set_default_index(vocab['<unk>'])\n"
  343. ]
  344. },
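{
"cell_type": "markdown",
"id": "b3f1a2c9",
"metadata": {},
"source": [
"As a quick sanity check, we can map a sample sentence through the tokenizer and vocabulary to see the word-to-ID mapping; out-of-vocabulary words fall back to the `<unk>` index set above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4e5f6a7",
"metadata": {},
"outputs": [],
"source": [
"# out-of-vocabulary words map to the <unk> index\n",
"sample = \"patients were randomized into two groups\"\n",
"print(vocab(tokenizer(sample)))"
]
},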
  345. {
  346. "cell_type": "markdown",
  347. "id": "221bdc48",
  348. "metadata": {},
  349. "source": [
  350. "### now convert to tensor\n"
  351. ]
  352. },
  353. {
  354. "cell_type": "code",
  355. "execution_count": 15,
  356. "id": "0e5bc361",
  357. "metadata": {},
  358. "outputs": [],
  359. "source": [
  360. "def data_process(tokens_iter):\n",
  361. " \"\"\"Converts raw text into a flat Tensor.\"\"\"\n",
  362. " data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in tokens_iter]\n",
  363. " return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))"
  364. ]
  365. },
  366. {
  367. "cell_type": "code",
  368. "execution_count": 16,
  369. "id": "dfd7400d",
  370. "metadata": {},
  371. "outputs": [],
  372. "source": [
  373. "train_generator = train_abstract_iter()\n",
  374. "val_generator = val_abstract_iter()\n",
  375. "test_generator = test_abstract_iter()\n",
  376. "train_data = data_process(train_generator)\n",
  377. "val_data = data_process(val_generator)\n",
  378. "test_data = data_process(test_generator)"
  379. ]
  380. },
  381. {
  382. "cell_type": "markdown",
  383. "id": "c49a2734",
  384. "metadata": {},
  385. "source": [
  386. "### check gpu"
  387. ]
  388. },
  389. {
  390. "cell_type": "code",
  391. "execution_count": 20,
  392. "id": "c155ee31",
  393. "metadata": {},
  394. "outputs": [],
  395. "source": [
  396. "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
  397. ]
  398. },
  399. {
  400. "cell_type": "code",
  401. "execution_count": 21,
  402. "id": "79b2d248",
  403. "metadata": {},
  404. "outputs": [
  405. {
  406. "data": {
  407. "text/plain": [
  408. "device(type='cuda')"
  409. ]
  410. },
  411. "execution_count": 21,
  412. "metadata": {},
  413. "output_type": "execute_result"
  414. }
  415. ],
  416. "source": [
  417. "device"
  418. ]
  419. },
  420. {
  421. "cell_type": "markdown",
  422. "id": "2150ba71",
  423. "metadata": {},
  424. "source": [
  425. "### define model"
  426. ]
  427. },
  428. {
  429. "cell_type": "code",
  430. "execution_count": 22,
  431. "id": "a33d722f",
  432. "metadata": {},
  433. "outputs": [],
  434. "source": [
  435. "from typing import Tuple\n",
  436. "\n",
  437. "from torch import nn, Tensor\n",
  438. "import torch.nn.functional as F\n",
  439. "from torch.nn import TransformerEncoder, TransformerEncoderLayer\n",
  440. "from torch.utils.data import dataset\n",
  441. "\n",
  442. "class TransformerModel(nn.Module):\n",
  443. "\n",
  444. " def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,\n",
  445. " nlayers: int, dropout: float = 0.5):\n",
  446. " super().__init__()\n",
  447. " self.model_type = 'Transformer'\n",
  448. " self.pos_encoder = PositionalEncoding(d_model, dropout)\n",
  449. " encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)\n",
  450. " self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n",
  451. " self.encoder = nn.Embedding(ntoken, d_model)\n",
  452. " self.d_model = d_model\n",
  453. " self.decoder = nn.Linear(d_model, ntoken)\n",
  454. "\n",
  455. " self.init_weights()\n",
  456. "\n",
  457. " def init_weights(self) -> None:\n",
  458. " initrange = 0.1\n",
  459. " self.encoder.weight.data.uniform_(-initrange, initrange)\n",
  460. " self.decoder.bias.data.zero_()\n",
  461. " self.decoder.weight.data.uniform_(-initrange, initrange)\n",
  462. "\n",
  463. " def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:\n",
  464. " \"\"\"\n",
  465. " Args:\n",
  466. " src: Tensor, shape [seq_len, batch_size]\n",
  467. " src_mask: Tensor, shape [seq_len, seq_len]\n",
  468. "\n",
  469. " Returns:\n",
  470. " output Tensor of shape [seq_len, batch_size, ntoken]\n",
  471. " \"\"\"\n",
  472. " src = self.encoder(src) * math.sqrt(self.d_model)\n",
  473. " src = self.pos_encoder(src)\n",
  474. " output = self.transformer_encoder(src, src_mask)\n",
  475. " output = self.decoder(output)\n",
  476. " return output\n",
  477. "\n",
  478. "\n",
  479. "def generate_square_subsequent_mask(sz: int) -> Tensor:\n",
  480. " \"\"\"Generates an upper-triangular matrix of -inf, with zeros on diag.\"\"\"\n",
  481. " return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)"
  482. ]
  483. },
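{
"cell_type": "markdown",
"id": "1f2e3d4c",
"metadata": {},
"source": [
"As a small illustration, the causal mask for a length-5 sequence is upper-triangular: position i may only attend to positions <= i, because the -inf entries are suppressed by the softmax inside the attention layers."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a6b7c8d",
"metadata": {},
"outputs": [],
"source": [
"# -inf above the diagonal blocks attention to future tokens\n",
"print(generate_square_subsequent_mask(5))"
]
},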
  484. {
  485. "cell_type": "markdown",
  486. "id": "23268efe",
  487. "metadata": {},
  488. "source": [
  489. "### define pos encoder"
  490. ]
  491. },
  492. {
  493. "cell_type": "code",
  494. "execution_count": 23,
  495. "id": "c2f6d33b",
  496. "metadata": {},
  497. "outputs": [],
  498. "source": [
  499. "class PositionalEncoding(nn.Module):\n",
  500. "\n",
  501. " def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):\n",
  502. " super().__init__()\n",
  503. " self.dropout = nn.Dropout(p=dropout)\n",
  504. "\n",
  505. " position = torch.arange(max_len).unsqueeze(1)\n",
  506. " div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
  507. " pe = torch.zeros(max_len, 1, d_model)\n",
  508. " pe[:, 0, 0::2] = torch.sin(position * div_term)\n",
  509. " pe[:, 0, 1::2] = torch.cos(position * div_term)\n",
  510. " self.register_buffer('pe', pe)\n",
  511. "\n",
  512. " def forward(self, x: Tensor) -> Tensor:\n",
  513. " \"\"\"\n",
  514. " Args:\n",
  515. " x: Tensor, shape [seq_len, batch_size, embedding_dim]\n",
  516. " \"\"\"\n",
  517. " x = x + self.pe[:x.size(0)]\n",
  518. " return self.dropout(x)\n"
  519. ]
  520. },
  521. {
  522. "cell_type": "markdown",
  523. "id": "306352f5",
  524. "metadata": {},
  525. "source": [
  526. "### define function to create batches of data and create batches"
  527. ]
  528. },
  529. {
  530. "cell_type": "code",
  531. "execution_count": 24,
  532. "id": "9e184841",
  533. "metadata": {},
  534. "outputs": [],
  535. "source": [
  536. "def batchify(data: Tensor, bsz: int) -> Tensor:\n",
  537. " \"\"\"Divides the data into bsz separate sequences, removing extra elements\n",
  538. " that wouldn't cleanly fit.\n",
  539. "\n",
  540. " Args:\n",
  541. " data: Tensor, shape [N]\n",
  542. " bsz: int, batch size\n",
  543. "\n",
  544. " Returns:\n",
  545. " Tensor of shape [N // bsz, bsz]\n",
  546. " \"\"\"\n",
  547. " seq_len = data.size(0) // bsz\n",
  548. " data = data[:seq_len * bsz]\n",
  549. " data = data.view(bsz, seq_len).t().contiguous()\n",
  550. " return data.to(device)"
  551. ]
  552. },
  553. {
  554. "cell_type": "code",
  555. "execution_count": 25,
  556. "id": "a4def1ac",
  557. "metadata": {},
  558. "outputs": [],
  559. "source": [
  560. "batch_size = 20\n",
  561. "eval_batch_size = 10\n",
  562. "train_data = batchify(train_data, batch_size) # shape [seq_len, batch_size]\n",
  563. "val_data = batchify(val_data, eval_batch_size)\n",
  564. "test_data = batchify(test_data, eval_batch_size)"
  565. ]
  566. },
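{
"cell_type": "markdown",
"id": "9c8d7e6f",
"metadata": {},
"source": [
"A quick shape check: each batched tensor should now be [seq_len, batch_size], i.e. one long token stream per column."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b3c4d5e",
"metadata": {},
"outputs": [],
"source": [
"# columns are independent token streams; rows are time steps\n",
"print(train_data.shape, val_data.shape, test_data.shape)"
]
},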
  567. {
  568. "cell_type": "markdown",
  569. "id": "c658cb42",
  570. "metadata": {},
  571. "source": [
  572. "### define function to get batch"
  573. ]
  574. },
  575. {
  576. "cell_type": "code",
  577. "execution_count": 26,
  578. "id": "4ab5b8fd",
  579. "metadata": {},
  580. "outputs": [],
  581. "source": [
  582. "bptt = 35\n",
  583. "def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:\n",
  584. " \"\"\"\n",
  585. " Args:\n",
  586. " source: Tensor, shape [full_seq_len, batch_size]\n",
  587. " i: int\n",
  588. "\n",
  589. " Returns:\n",
  590. " tuple (data, target), where data has shape [seq_len, batch_size] and\n",
  591. " target has shape [seq_len * batch_size]\n",
  592. " \"\"\"\n",
  593. " seq_len = min(bptt, len(source) - 1 - i)\n",
  594. " data = source[i:i+seq_len]\n",
  595. " target = source[i+1:i+1+seq_len].reshape(-1)\n",
  596. " return data, target"
  597. ]
  598. },
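{
"cell_type": "markdown",
"id": "7d8e9f0a",
"metadata": {},
"source": [
"To make the input/target relationship concrete: the target is the input window shifted by one position and flattened, so the model learns to predict the next token at every step."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e4f5a6b",
"metadata": {},
"outputs": [],
"source": [
"# data is a bptt-long window; target is the same window shifted by one token\n",
"data, target = get_batch(train_data, 0)\n",
"print(data.shape, target.shape)  # expected: [35, 20] and [700]"
]
},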
  599. {
  600. "cell_type": "markdown",
  601. "id": "d6392484",
  602. "metadata": {},
  603. "source": [
  604. "### define parameters and init model"
  605. ]
  606. },
  607. {
  608. "cell_type": "code",
  609. "execution_count": 27,
  610. "id": "c53764da",
  611. "metadata": {},
  612. "outputs": [],
  613. "source": [
  614. "ntokens = len(vocab) # size of vocabulary\n",
  615. "emsize = 200 # embedding dimension\n",
  616. "d_hid = 200 # dimension of the feedforward network model in nn.TransformerEncoder\n",
  617. "nlayers = 2 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder\n",
  618. "nhead = 2 # number of heads in nn.MultiheadAttention\n",
  619. "dropout = 0.2 # dropout probability\n",
  620. "model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)"
  621. ]
  622. },
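{
"cell_type": "markdown",
"id": "8f0a1b2c",
"metadata": {},
"source": [
"As a quick check, the total number of trainable parameters gives a sense of the model's size under these hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c7d8e9f",
"metadata": {},
"outputs": [],
"source": [
"# total number of trainable parameters in the model\n",
"print(sum(p.numel() for p in model.parameters() if p.requires_grad))"
]
},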
  623. {
  624. "cell_type": "markdown",
  625. "id": "7fb67d72",
  626. "metadata": {},
  627. "source": [
  628. "### init optimizer, loss, scheduler etc."
  629. ]
  630. },
  631. {
  632. "cell_type": "code",
  633. "execution_count": 28,
  634. "id": "ddaa1d64",
  635. "metadata": {},
  636. "outputs": [],
  637. "source": [
  638. "import copy\n",
  639. "import time\n",
  640. "\n",
  641. "criterion = nn.CrossEntropyLoss()\n",
  642. "lr = 5.0 # learning rate\n",
  643. "optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
  644. "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)"
  645. ]
  646. },
  647. {
  648. "cell_type": "markdown",
  649. "id": "dda19446",
  650. "metadata": {},
  651. "source": [
  652. "### define train function"
  653. ]
  654. },
  655. {
  656. "cell_type": "code",
  657. "execution_count": 29,
  658. "id": "50ab3fb6",
  659. "metadata": {},
  660. "outputs": [],
  661. "source": [
  662. "def train(model: nn.Module) -> None:\n",
  663. " model.train() # turn on train mode\n",
  664. " total_loss = 0.\n",
  665. " log_interval = 200\n",
  666. " start_time = time.time()\n",
  667. " src_mask = generate_square_subsequent_mask(bptt).to(device)\n",
  668. "\n",
  669. " num_batches = len(train_data) // bptt\n",
  670. " for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n",
  671. " data, targets = get_batch(train_data, i)\n",
  672. " seq_len = data.size(0)\n",
  673. " if seq_len != bptt: # only on last batch\n",
  674. " src_mask = src_mask[:seq_len, :seq_len]\n",
  675. " output = model(data, src_mask)\n",
  676. " loss = criterion(output.view(-1, ntokens), targets)\n",
  677. "\n",
  678. " optimizer.zero_grad()\n",
  679. " loss.backward()\n",
  680. " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n",
  681. " optimizer.step()\n",
  682. "\n",
  683. " total_loss += loss.item()\n",
  684. " if batch % log_interval == 0 and batch > 0:\n",
  685. " lr = scheduler.get_last_lr()[0]\n",
  686. " ms_per_batch = (time.time() - start_time) * 1000 / log_interval\n",
  687. " cur_loss = total_loss / log_interval\n",
  688. " ppl = math.exp(cur_loss)\n",
  689. " print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '\n",
  690. " f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n",
  691. " f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n",
  692. " total_loss = 0\n",
  693. " start_time = time.time()"
  694. ]
  695. },
  696. {
  697. "cell_type": "markdown",
  698. "id": "9756c092",
  699. "metadata": {},
  700. "source": [
  701. "### define evaluate function"
  702. ]
  703. },
  704. {
  705. "cell_type": "code",
  706. "execution_count": 30,
  707. "id": "3d179bb0",
  708. "metadata": {},
  709. "outputs": [],
  710. "source": [
  711. "def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n",
  712. " model.eval() # turn on evaluation mode\n",
  713. " total_loss = 0.\n",
  714. " src_mask = generate_square_subsequent_mask(bptt).to(device)\n",
  715. " with torch.no_grad():\n",
  716. " for i in range(0, eval_data.size(0) - 1, bptt):\n",
  717. " data, targets = get_batch(eval_data, i)\n",
  718. " seq_len = data.size(0)\n",
  719. " if seq_len != bptt:\n",
  720. " src_mask = src_mask[:seq_len, :seq_len]\n",
  721. " output = model(data, src_mask)\n",
  722. " output_flat = output.view(-1, ntokens)\n",
  723. " total_loss += seq_len * criterion(output_flat, targets).item()\n",
  724. " return total_loss / (len(eval_data) - 1)"
  725. ]
  726. },
  727. {
  728. "cell_type": "markdown",
  729. "id": "5a959f09",
  730. "metadata": {},
  731. "source": [
  732. "### now we can start training the model while saving best one"
  733. ]
  734. },
  735. {
  736. "cell_type": "code",
  737. "execution_count": 31,
  738. "id": "09c4d4ce",
  739. "metadata": {
  740. "scrolled": true
  741. },
  742. "outputs": [
  743. {
  744. "name": "stdout",
  745. "output_type": "stream",
  746. "text": [
  747. "| epoch 1 | 200/13484 batches | lr 5.00 | ms/batch 116.24 | loss 9.27 | ppl 10651.55\n",
  748. "| epoch 1 | 400/13484 batches | lr 5.00 | ms/batch 114.02 | loss 7.49 | ppl 1787.62\n",
  749. "| epoch 1 | 600/13484 batches | lr 5.00 | ms/batch 114.33 | loss 6.83 | ppl 923.44\n",
  750. "| epoch 1 | 800/13484 batches | lr 5.00 | ms/batch 114.54 | loss 6.54 | ppl 693.98\n",
  751. "| epoch 1 | 1000/13484 batches | lr 5.00 | ms/batch 114.73 | loss 6.33 | ppl 563.29\n",
  752. "| epoch 1 | 1200/13484 batches | lr 5.00 | ms/batch 114.85 | loss 6.18 | ppl 485.05\n",
  753. "| epoch 1 | 1400/13484 batches | lr 5.00 | ms/batch 114.91 | loss 6.09 | ppl 440.69\n",
  754. "| epoch 1 | 1600/13484 batches | lr 5.00 | ms/batch 115.00 | loss 6.06 | ppl 428.38\n",
  755. "| epoch 1 | 1800/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.98 | ppl 397.07\n",
  756. "| epoch 1 | 2000/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.91 | ppl 369.13\n",
  757. "| epoch 1 | 2200/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.89 | ppl 360.14\n",
  758. "| epoch 1 | 2400/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.83 | ppl 341.10\n",
  759. "| epoch 1 | 2600/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.78 | ppl 322.33\n",
  760. "| epoch 1 | 2800/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.80 | ppl 329.27\n",
  761. "| epoch 1 | 3000/13484 batches | lr 5.00 | ms/batch 115.12 | loss 5.77 | ppl 321.64\n",
  762. "| epoch 1 | 3200/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.71 | ppl 303.37\n",
  763. "| epoch 1 | 3400/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.74 | ppl 311.04\n",
  764. "| epoch 1 | 3600/13484 batches | lr 5.00 | ms/batch 115.15 | loss 5.70 | ppl 299.44\n",
  765. "| epoch 1 | 3800/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.68 | ppl 292.67\n",
  766. "| epoch 1 | 4000/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.59 | ppl 268.70\n",
  767. "| epoch 1 | 4200/13484 batches | lr 5.00 | ms/batch 115.19 | loss 5.62 | ppl 275.23\n",
  768. "| epoch 1 | 4400/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.63 | ppl 277.51\n",
  769. "| epoch 1 | 4600/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.66 | ppl 286.99\n",
  770. "| epoch 1 | 4800/13484 batches | lr 5.00 | ms/batch 115.30 | loss 5.62 | ppl 276.08\n",
  771. "| epoch 1 | 5000/13484 batches | lr 5.00 | ms/batch 115.15 | loss 5.61 | ppl 272.68\n",
  772. "| epoch 1 | 5200/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.59 | ppl 268.83\n",
  773. "| epoch 1 | 5400/13484 batches | lr 5.00 | ms/batch 115.29 | loss 5.55 | ppl 257.80\n",
  774. "| epoch 1 | 5600/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.57 | ppl 261.32\n",
  775. "| epoch 1 | 5800/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.55 | ppl 257.06\n",
  776. "| epoch 1 | 6000/13484 batches | lr 5.00 | ms/batch 115.26 | loss 5.56 | ppl 259.08\n",
  777. "| epoch 1 | 6200/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.57 | ppl 262.89\n",
  778. "| epoch 1 | 6400/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.54 | ppl 254.66\n",
  779. "| epoch 1 | 6600/13484 batches | lr 5.00 | ms/batch 115.27 | loss 5.57 | ppl 263.01\n",
  780. "| epoch 1 | 6800/13484 batches | lr 5.00 | ms/batch 115.21 | loss 5.51 | ppl 246.13\n",
  781. "| epoch 1 | 7000/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.57 | ppl 261.50\n",
  782. "| epoch 1 | 7200/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.51 | ppl 247.48\n",
  783. "| epoch 1 | 7400/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.50 | ppl 245.45\n",
  784. "| epoch 1 | 7600/13484 batches | lr 5.00 | ms/batch 115.26 | loss 5.51 | ppl 247.79\n",
  785. "| epoch 1 | 7800/13484 batches | lr 5.00 | ms/batch 115.27 | loss 5.50 | ppl 245.74\n",
  786. "| epoch 1 | 8000/13484 batches | lr 5.00 | ms/batch 115.33 | loss 5.48 | ppl 240.49\n",
  787. "| epoch 1 | 8200/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.48 | ppl 238.87\n",
  788. "| epoch 1 | 8400/13484 batches | lr 5.00 | ms/batch 115.27 | loss 5.49 | ppl 241.45\n",
  789. "| epoch 1 | 8600/13484 batches | lr 5.00 | ms/batch 115.23 | loss 5.47 | ppl 236.88\n",
  790. "| epoch 1 | 8800/13484 batches | lr 5.00 | ms/batch 115.28 | loss 5.47 | ppl 236.31\n",
  791. "| epoch 1 | 9000/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.48 | ppl 240.63\n",
  792. "| epoch 1 | 9200/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.48 | ppl 239.53\n",
  793. "| epoch 1 | 9400/13484 batches | lr 5.00 | ms/batch 115.29 | loss 5.48 | ppl 238.75\n",
  794. "| epoch 1 | 9600/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.43 | ppl 229.14\n",
  795. "| epoch 1 | 9800/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.42 | ppl 226.49\n",
  796. "| epoch 1 | 10000/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.47 | ppl 236.79\n",
  797. "| epoch 1 | 10200/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.41 | ppl 223.98\n",
  798. "| epoch 1 | 10400/13484 batches | lr 5.00 | ms/batch 115.16 | loss 5.39 | ppl 219.63\n",
  799. "| epoch 1 | 10600/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.42 | ppl 225.37\n",
  800. "| epoch 1 | 10800/13484 batches | lr 5.00 | ms/batch 115.30 | loss 5.45 | ppl 232.44\n",
  801. "| epoch 1 | 11000/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.45 | ppl 232.12\n",
  802. "| epoch 1 | 11200/13484 batches | lr 5.00 | ms/batch 115.21 | loss 5.43 | ppl 228.71\n",
  803. "| epoch 1 | 11400/13484 batches | lr 5.00 | ms/batch 115.32 | loss 5.38 | ppl 216.73\n",
  804. "| epoch 1 | 11600/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.41 | ppl 222.68\n",
  805. "| epoch 1 | 11800/13484 batches | lr 5.00 | ms/batch 115.28 | loss 5.39 | ppl 218.39\n",
  806. "| epoch 1 | 12000/13484 batches | lr 5.00 | ms/batch 115.17 | loss 5.44 | ppl 229.94\n",
  807. "| epoch 1 | 12200/13484 batches | lr 5.00 | ms/batch 115.20 | loss 5.36 | ppl 213.26\n",
  808. "| epoch 1 | 12400/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.38 | ppl 217.41\n",
  809. "| epoch 1 | 12600/13484 batches | lr 5.00 | ms/batch 115.24 | loss 5.40 | ppl 222.35\n",
  810. "| epoch 1 | 12800/13484 batches | lr 5.00 | ms/batch 115.22 | loss 5.41 | ppl 224.63\n",
  811. "| epoch 1 | 13000/13484 batches | lr 5.00 | ms/batch 115.29 | loss 5.40 | ppl 220.79\n",
  812. "| epoch 1 | 13200/13484 batches | lr 5.00 | ms/batch 115.16 | loss 5.41 | ppl 223.58\n",
  813. "| epoch 1 | 13400/13484 batches | lr 5.00 | ms/batch 115.25 | loss 5.42 | ppl 225.49\n",
  814. "-----------------------------------------------------------------------------------------\n",
  815. "| end of epoch 1 | time: 1625.43s | valid loss 5.35 | valid ppl 210.93\n",
  816. "-----------------------------------------------------------------------------------------\n",
  817. "| epoch 2 | 200/13484 batches | lr 4.75 | ms/batch 115.84 | loss 5.44 | ppl 229.80\n",
  818. "| epoch 2 | 400/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.38 | ppl 216.74\n",
  819. "| epoch 2 | 600/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.35 | ppl 211.15\n",
  820. "| epoch 2 | 800/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.37 | ppl 215.74\n",
  821. "| epoch 2 | 1000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.35 | ppl 210.96\n",
  822. "| epoch 2 | 1200/13484 batches | lr 4.75 | ms/batch 115.17 | loss 5.33 | ppl 207.12\n",
  823. "| epoch 2 | 1400/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.34 | ppl 208.70\n",
  824. "| epoch 2 | 1600/13484 batches | lr 4.75 | ms/batch 115.14 | loss 5.36 | ppl 212.80\n",
  825. "| epoch 2 | 1800/13484 batches | lr 4.75 | ms/batch 115.11 | loss 5.35 | ppl 209.96\n",
  826. "| epoch 2 | 2000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.32 | ppl 203.54\n",
  827. "| epoch 2 | 2200/13484 batches | lr 4.75 | ms/batch 115.15 | loss 5.33 | ppl 205.82\n",
  828. "| epoch 2 | 2400/13484 batches | lr 4.75 | ms/batch 115.23 | loss 5.34 | ppl 208.95\n",
  829. "| epoch 2 | 2600/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.29 | ppl 199.16\n",
  830. "| epoch 2 | 2800/13484 batches | lr 4.75 | ms/batch 115.16 | loss 5.34 | ppl 208.19\n",
  831. "| epoch 2 | 3000/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.33 | ppl 205.88\n",
  832. "| epoch 2 | 3200/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.29 | ppl 198.11\n",
  833. "| epoch 2 | 3400/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.37 | ppl 214.29\n",
  834. "| epoch 2 | 3600/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.31 | ppl 202.72\n",
  835. "| epoch 2 | 3800/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.32 | ppl 203.84\n",
  836. "| epoch 2 | 4000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.24 | ppl 189.14\n",
  837. "| epoch 2 | 4200/13484 batches | lr 4.75 | ms/batch 115.15 | loss 5.28 | ppl 196.95\n",
  838. "| epoch 2 | 4400/13484 batches | lr 4.75 | ms/batch 115.17 | loss 5.29 | ppl 198.84\n"
  839. ]
  840. },
  841. {
  842. "name": "stdout",
  843. "output_type": "stream",
  844. "text": [
  845. "| epoch 2 | 4600/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.35 | ppl 210.15\n",
  846. "| epoch 2 | 4800/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.32 | ppl 204.37\n",
  847. "| epoch 2 | 5000/13484 batches | lr 4.75 | ms/batch 115.29 | loss 5.33 | ppl 205.42\n",
  848. "| epoch 2 | 5200/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.31 | ppl 201.44\n",
  849. "| epoch 2 | 5400/13484 batches | lr 4.75 | ms/batch 115.23 | loss 5.30 | ppl 200.48\n",
  850. "| epoch 2 | 5600/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.29 | ppl 197.76\n",
  851. "| epoch 2 | 5800/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.34 | ppl 207.65\n",
  852. "| epoch 2 | 6000/13484 batches | lr 4.75 | ms/batch 115.11 | loss 5.32 | ppl 204.89\n",
  853. "| epoch 2 | 6200/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.34 | ppl 209.24\n",
  854. "| epoch 2 | 6400/13484 batches | lr 4.75 | ms/batch 115.14 | loss 5.31 | ppl 201.48\n",
  855. "| epoch 2 | 6600/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.36 | ppl 212.87\n",
  856. "| epoch 2 | 6800/13484 batches | lr 4.75 | ms/batch 115.13 | loss 5.29 | ppl 198.41\n",
  857. "| epoch 2 | 7000/13484 batches | lr 4.75 | ms/batch 115.16 | loss 5.35 | ppl 211.39\n",
  858. "| epoch 2 | 7200/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.30 | ppl 199.94\n",
  859. "| epoch 2 | 7400/13484 batches | lr 4.75 | ms/batch 115.11 | loss 5.30 | ppl 200.81\n",
  860. "| epoch 2 | 7600/13484 batches | lr 4.75 | ms/batch 115.31 | loss 5.35 | ppl 211.20\n",
  861. "| epoch 2 | 7800/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.31 | ppl 201.93\n",
  862. "| epoch 2 | 8000/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.29 | ppl 198.24\n",
  863. "| epoch 2 | 8200/13484 batches | lr 4.75 | ms/batch 115.14 | loss 5.27 | ppl 194.75\n",
  864. "| epoch 2 | 8400/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.29 | ppl 198.48\n",
  865. "| epoch 2 | 8600/13484 batches | lr 4.75 | ms/batch 115.13 | loss 5.29 | ppl 198.11\n",
  866. "| epoch 2 | 8800/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.34 | ppl 207.62\n",
  867. "| epoch 2 | 9000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.33 | ppl 205.55\n",
  868. "| epoch 2 | 9200/13484 batches | lr 4.75 | ms/batch 115.27 | loss 5.33 | ppl 206.24\n",
  869. "| epoch 2 | 9400/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.31 | ppl 201.81\n",
  870. "| epoch 2 | 9600/13484 batches | lr 4.75 | ms/batch 115.25 | loss 5.29 | ppl 198.63\n",
  871. "| epoch 2 | 9800/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.26 | ppl 192.87\n",
  872. "| epoch 2 | 10000/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.30 | ppl 199.77\n",
  873. "| epoch 2 | 10200/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.25 | ppl 191.30\n",
  874. "| epoch 2 | 10400/13484 batches | lr 4.75 | ms/batch 115.24 | loss 5.22 | ppl 184.78\n",
  875. "| epoch 2 | 10600/13484 batches | lr 4.75 | ms/batch 115.20 | loss 5.27 | ppl 194.07\n",
  876. "| epoch 2 | 10800/13484 batches | lr 4.75 | ms/batch 115.23 | loss 5.30 | ppl 200.53\n",
  877. "| epoch 2 | 11000/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.29 | ppl 198.68\n",
  878. "| epoch 2 | 11200/13484 batches | lr 4.75 | ms/batch 115.21 | loss 5.28 | ppl 196.43\n",
  879. "| epoch 2 | 11400/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.23 | ppl 186.61\n",
  880. "| epoch 2 | 11600/13484 batches | lr 4.75 | ms/batch 115.13 | loss 5.27 | ppl 195.11\n",
  881. "| epoch 2 | 11800/13484 batches | lr 4.75 | ms/batch 115.19 | loss 5.23 | ppl 186.19\n",
  882. "| epoch 2 | 12000/13484 batches | lr 4.75 | ms/batch 115.22 | loss 5.31 | ppl 202.19\n",
  883. "| epoch 2 | 12200/13484 batches | lr 4.75 | ms/batch 115.18 | loss 5.22 | ppl 184.46\n",
  884. "| epoch 2 | 12400/13484 batches | lr 4.75 | ms/batch 115.32 | loss 5.23 | ppl 187.26\n",
  885. "| epoch 2 | 12600/13484 batches | lr 4.75 | ms/batch 115.31 | loss 5.25 | ppl 189.65\n",
  886. "| epoch 2 | 12800/13484 batches | lr 4.75 | ms/batch 115.26 | loss 5.28 | ppl 196.25\n",
  887. "| epoch 2 | 13000/13484 batches | lr 4.75 | ms/batch 115.35 | loss 5.28 | ppl 196.31\n",
  888. "| epoch 2 | 13200/13484 batches | lr 4.75 | ms/batch 115.32 | loss 5.28 | ppl 195.61\n",
  889. "| epoch 2 | 13400/13484 batches | lr 4.75 | ms/batch 115.27 | loss 5.28 | ppl 195.80\n",
  890. "-----------------------------------------------------------------------------------------\n",
  891. "| end of epoch 2 | time: 1625.71s | valid loss 5.24 | valid ppl 188.48\n",
  892. "-----------------------------------------------------------------------------------------\n",
  893. "| epoch 3 | 200/13484 batches | lr 4.51 | ms/batch 115.84 | loss 5.32 | ppl 205.41\n",
  894. "| epoch 3 | 400/13484 batches | lr 4.51 | ms/batch 115.17 | loss 5.28 | ppl 195.56\n",
  895. "| epoch 3 | 600/13484 batches | lr 4.51 | ms/batch 115.12 | loss 5.22 | ppl 184.23\n",
  896. "| epoch 3 | 800/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.23 | ppl 187.41\n",
  897. "| epoch 3 | 1000/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.23 | ppl 186.77\n",
  898. "| epoch 3 | 1200/13484 batches | lr 4.51 | ms/batch 115.14 | loss 5.22 | ppl 184.68\n",
  899. "| epoch 3 | 1400/13484 batches | lr 4.51 | ms/batch 115.18 | loss 5.20 | ppl 181.17\n",
  900. "| epoch 3 | 1600/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.25 | ppl 191.20\n",
  901. "| epoch 3 | 1800/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.23 | ppl 186.87\n",
  902. "| epoch 3 | 2000/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.19 | ppl 180.16\n",
  903. "| epoch 3 | 2200/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.21 | ppl 183.82\n",
  904. "| epoch 3 | 2400/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.21 | ppl 182.76\n",
  905. "| epoch 3 | 2600/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.19 | ppl 180.25\n",
  906. "| epoch 3 | 2800/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.22 | ppl 185.75\n",
  907. "| epoch 3 | 3000/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.21 | ppl 183.06\n",
  908. "| epoch 3 | 3200/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.17 | ppl 176.28\n",
  909. "| epoch 3 | 3400/13484 batches | lr 4.51 | ms/batch 115.16 | loss 5.24 | ppl 187.88\n",
  910. "| epoch 3 | 3600/13484 batches | lr 4.51 | ms/batch 115.16 | loss 5.21 | ppl 182.87\n",
  911. "| epoch 3 | 3800/13484 batches | lr 4.51 | ms/batch 115.18 | loss 5.21 | ppl 182.52\n",
  912. "| epoch 3 | 4000/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.15 | ppl 172.43\n",
  913. "| epoch 3 | 4200/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.18 | ppl 177.72\n",
  914. "| epoch 3 | 4400/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.19 | ppl 179.22\n",
  915. "| epoch 3 | 4600/13484 batches | lr 4.51 | ms/batch 115.26 | loss 5.24 | ppl 187.99\n",
  916. "| epoch 3 | 4800/13484 batches | lr 4.51 | ms/batch 115.17 | loss 5.24 | ppl 188.20\n",
  917. "| epoch 3 | 5000/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.22 | ppl 184.24\n",
  918. "| epoch 3 | 5200/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.21 | ppl 184.01\n",
  919. "| epoch 3 | 5400/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.19 | ppl 178.94\n",
  920. "| epoch 3 | 5600/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.19 | ppl 180.15\n",
  921. "| epoch 3 | 5800/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.20 | ppl 181.24\n",
  922. "| epoch 3 | 6000/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.22 | ppl 184.08\n",
  923. "| epoch 3 | 6200/13484 batches | lr 4.51 | ms/batch 115.29 | loss 5.24 | ppl 187.77\n",
  924. "| epoch 3 | 6400/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.21 | ppl 182.36\n",
  925. "| epoch 3 | 6600/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.25 | ppl 190.52\n",
  926. "| epoch 3 | 6800/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.20 | ppl 180.56\n",
  927. "| epoch 3 | 7000/13484 batches | lr 4.51 | ms/batch 115.14 | loss 5.23 | ppl 186.73\n",
  928. "| epoch 3 | 7200/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.19 | ppl 179.90\n",
  929. "| epoch 3 | 7400/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.21 | ppl 182.43\n",
  930. "| epoch 3 | 7600/13484 batches | lr 4.51 | ms/batch 115.09 | loss 5.20 | ppl 181.48\n",
  931. "| epoch 3 | 7800/13484 batches | lr 4.51 | ms/batch 115.26 | loss 5.22 | ppl 185.25\n",
  932. "| epoch 3 | 8000/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.18 | ppl 178.05\n",
  933. "| epoch 3 | 8200/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.18 | ppl 178.41\n",
  934. "| epoch 3 | 8400/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.20 | ppl 181.07\n",
  935. "| epoch 3 | 8600/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.20 | ppl 182.10\n",
  936. "| epoch 3 | 8800/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.18 | ppl 177.86\n"
  937. ]
  938. },
  939. {
  940. "name": "stdout",
  941. "output_type": "stream",
  942. "text": [
  943. "| epoch 3 | 9000/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.21 | ppl 182.88\n",
  944. "| epoch 3 | 9200/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.21 | ppl 183.47\n",
  945. "| epoch 3 | 9400/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.22 | ppl 184.25\n",
  946. "| epoch 3 | 9600/13484 batches | lr 4.51 | ms/batch 115.27 | loss 5.18 | ppl 177.24\n",
  947. "| epoch 3 | 9800/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.16 | ppl 174.24\n",
  948. "| epoch 3 | 10000/13484 batches | lr 4.51 | ms/batch 115.22 | loss 5.21 | ppl 182.29\n",
  949. "| epoch 3 | 10200/13484 batches | lr 4.51 | ms/batch 115.13 | loss 5.17 | ppl 175.34\n",
  950. "| epoch 3 | 10400/13484 batches | lr 4.51 | ms/batch 115.28 | loss 5.14 | ppl 170.79\n",
  951. "| epoch 3 | 10600/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.17 | ppl 176.55\n",
  952. "| epoch 3 | 10800/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.22 | ppl 185.77\n",
  953. "| epoch 3 | 11000/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.19 | ppl 179.38\n",
  954. "| epoch 3 | 11200/13484 batches | lr 4.51 | ms/batch 115.30 | loss 5.19 | ppl 179.59\n",
  955. "| epoch 3 | 11400/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.14 | ppl 171.38\n",
  956. "| epoch 3 | 11600/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.18 | ppl 178.51\n",
  957. "| epoch 3 | 11800/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.16 | ppl 174.52\n",
  958. "| epoch 3 | 12000/13484 batches | lr 4.51 | ms/batch 115.25 | loss 5.20 | ppl 181.28\n",
  959. "| epoch 3 | 12200/13484 batches | lr 4.51 | ms/batch 115.13 | loss 5.14 | ppl 170.89\n",
  960. "| epoch 3 | 12400/13484 batches | lr 4.51 | ms/batch 115.24 | loss 5.14 | ppl 169.88\n",
  961. "| epoch 3 | 12600/13484 batches | lr 4.51 | ms/batch 115.20 | loss 5.15 | ppl 172.67\n",
  962. "| epoch 3 | 12800/13484 batches | lr 4.51 | ms/batch 115.19 | loss 5.18 | ppl 176.89\n",
  963. "| epoch 3 | 13000/13484 batches | lr 4.51 | ms/batch 115.23 | loss 5.19 | ppl 179.90\n",
  964. "| epoch 3 | 13200/13484 batches | lr 4.51 | ms/batch 115.29 | loss 5.17 | ppl 175.90\n",
  965. "| epoch 3 | 13400/13484 batches | lr 4.51 | ms/batch 115.21 | loss 5.20 | ppl 182.07\n",
  966. "-----------------------------------------------------------------------------------------\n",
  967. "| end of epoch 3 | time: 1625.86s | valid loss 5.19 | valid ppl 178.94\n",
  968. "-----------------------------------------------------------------------------------------\n",
  969. "| epoch 4 | 200/13484 batches | lr 4.29 | ms/batch 115.82 | loss 5.22 | ppl 184.03\n",
  970. "| epoch 4 | 400/13484 batches | lr 4.29 | ms/batch 115.24 | loss 5.16 | ppl 174.84\n",
  971. "| epoch 4 | 600/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.14 | ppl 170.18\n",
  972. "| epoch 4 | 800/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.15 | ppl 171.64\n",
  973. "| epoch 4 | 1000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.14 | ppl 171.26\n",
  974. "| epoch 4 | 1200/13484 batches | lr 4.29 | ms/batch 115.14 | loss 5.14 | ppl 171.18\n",
  975. "| epoch 4 | 1400/13484 batches | lr 4.29 | ms/batch 115.13 | loss 5.12 | ppl 166.55\n",
  976. "| epoch 4 | 1600/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.17 | ppl 176.35\n",
  977. "| epoch 4 | 1800/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.15 | ppl 172.34\n",
  978. "| epoch 4 | 2000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.13 | ppl 169.46\n",
  979. "| epoch 4 | 2200/13484 batches | lr 4.29 | ms/batch 115.23 | loss 5.16 | ppl 173.74\n",
  980. "| epoch 4 | 2400/13484 batches | lr 4.29 | ms/batch 115.14 | loss 5.14 | ppl 170.76\n",
  981. "| epoch 4 | 2600/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.11 | ppl 165.36\n",
  982. "| epoch 4 | 2800/13484 batches | lr 4.29 | ms/batch 115.14 | loss 5.15 | ppl 173.15\n",
  983. "| epoch 4 | 3000/13484 batches | lr 4.29 | ms/batch 115.15 | loss 5.14 | ppl 171.39\n",
  984. "| epoch 4 | 3200/13484 batches | lr 4.29 | ms/batch 115.27 | loss 5.10 | ppl 164.27\n",
  985. "| epoch 4 | 3400/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.16 | ppl 174.64\n",
  986. "| epoch 4 | 3600/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.13 | ppl 168.98\n",
  987. "| epoch 4 | 3800/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.12 | ppl 167.42\n",
  988. "| epoch 4 | 4000/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.08 | ppl 161.02\n",
  989. "| epoch 4 | 4200/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.11 | ppl 165.33\n",
  990. "| epoch 4 | 4400/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.11 | ppl 165.61\n",
  991. "| epoch 4 | 4600/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.16 | ppl 173.90\n",
  992. "| epoch 4 | 4800/13484 batches | lr 4.29 | ms/batch 115.27 | loss 5.15 | ppl 172.81\n",
  993. "| epoch 4 | 5000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.14 | ppl 169.98\n",
  994. "| epoch 4 | 5200/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.13 | ppl 168.94\n",
  995. "| epoch 4 | 5400/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.10 | ppl 164.28\n",
  996. "| epoch 4 | 5600/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.12 | ppl 167.23\n",
  997. "| epoch 4 | 5800/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.12 | ppl 167.63\n",
  998. "| epoch 4 | 6000/13484 batches | lr 4.29 | ms/batch 115.34 | loss 5.14 | ppl 170.26\n",
  999. "| epoch 4 | 6200/13484 batches | lr 4.29 | ms/batch 115.31 | loss 5.18 | ppl 177.13\n",
  1000. "| epoch 4 | 6400/13484 batches | lr 4.29 | ms/batch 115.27 | loss 5.13 | ppl 169.45\n",
  1001. "| epoch 4 | 6600/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.16 | ppl 174.83\n",
  1002. "| epoch 4 | 6800/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.11 | ppl 165.20\n",
  1003. "| epoch 4 | 7000/13484 batches | lr 4.29 | ms/batch 115.19 | loss 5.16 | ppl 174.72\n",
  1004. "| epoch 4 | 7200/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.12 | ppl 167.83\n",
  1005. "| epoch 4 | 7400/13484 batches | lr 4.29 | ms/batch 115.29 | loss 5.12 | ppl 167.13\n",
  1006. "| epoch 4 | 7600/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.13 | ppl 168.29\n",
  1007. "| epoch 4 | 7800/13484 batches | lr 4.29 | ms/batch 115.30 | loss 5.12 | ppl 167.88\n",
  1008. "| epoch 4 | 8000/13484 batches | lr 4.29 | ms/batch 115.20 | loss 5.11 | ppl 165.65\n",
  1009. "| epoch 4 | 8200/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.10 | ppl 164.16\n",
  1010. "| epoch 4 | 8400/13484 batches | lr 4.29 | ms/batch 115.29 | loss 5.12 | ppl 166.71\n",
  1011. "| epoch 4 | 8600/13484 batches | lr 4.29 | ms/batch 115.32 | loss 5.14 | ppl 169.91\n",
  1012. "| epoch 4 | 8800/13484 batches | lr 4.29 | ms/batch 115.30 | loss 5.11 | ppl 166.00\n",
  1013. "| epoch 4 | 9000/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.13 | ppl 169.67\n",
  1014. "| epoch 4 | 9200/13484 batches | lr 4.29 | ms/batch 115.31 | loss 5.13 | ppl 169.46\n",
  1015. "| epoch 4 | 9400/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.15 | ppl 171.85\n",
  1016. "| epoch 4 | 9600/13484 batches | lr 4.29 | ms/batch 115.29 | loss 5.11 | ppl 165.01\n",
  1017. "| epoch 4 | 9800/13484 batches | lr 4.29 | ms/batch 115.21 | loss 5.09 | ppl 162.51\n",
  1018. "| epoch 4 | 10000/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.12 | ppl 167.81\n",
  1019. "| epoch 4 | 10200/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.10 | ppl 163.43\n",
  1020. "| epoch 4 | 10400/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.07 | ppl 158.79\n",
  1021. "| epoch 4 | 10600/13484 batches | lr 4.29 | ms/batch 115.30 | loss 5.10 | ppl 163.54\n",
  1022. "| epoch 4 | 10800/13484 batches | lr 4.29 | ms/batch 115.39 | loss 5.12 | ppl 167.03\n",
  1023. "| epoch 4 | 11000/13484 batches | lr 4.29 | ms/batch 115.33 | loss 5.11 | ppl 166.04\n",
  1024. "| epoch 4 | 11200/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.11 | ppl 165.67\n",
  1025. "| epoch 4 | 11400/13484 batches | lr 4.29 | ms/batch 115.26 | loss 5.06 | ppl 157.42\n",
  1026. "| epoch 4 | 11600/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.10 | ppl 164.17\n",
  1027. "| epoch 4 | 11800/13484 batches | lr 4.29 | ms/batch 115.36 | loss 5.07 | ppl 159.41\n",
  1028. "| epoch 4 | 12000/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.13 | ppl 168.33\n",
  1029. "| epoch 4 | 12200/13484 batches | lr 4.29 | ms/batch 115.24 | loss 5.05 | ppl 155.52\n",
  1030. "| epoch 4 | 12400/13484 batches | lr 4.29 | ms/batch 115.25 | loss 5.07 | ppl 159.62\n",
  1031. "| epoch 4 | 12600/13484 batches | lr 4.29 | ms/batch 115.24 | loss 5.09 | ppl 161.65\n",
  1032. "| epoch 4 | 12800/13484 batches | lr 4.29 | ms/batch 115.22 | loss 5.10 | ppl 164.49\n",
  1033. "| epoch 4 | 13000/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.10 | ppl 163.47\n",
  1034. "| epoch 4 | 13200/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.09 | ppl 162.89\n"
  1035. ]
  1036. },
  1037. {
  1038. "name": "stdout",
  1039. "output_type": "stream",
  1040. "text": [
  1041. "| epoch 4 | 13400/13484 batches | lr 4.29 | ms/batch 115.28 | loss 5.12 | ppl 166.66\n",
  1042. "-----------------------------------------------------------------------------------------\n",
  1043. "| end of epoch 4 | time: 1626.36s | valid loss 5.13 | valid ppl 168.54\n",
  1044. "-----------------------------------------------------------------------------------------\n",
  1045. "| epoch 5 | 200/13484 batches | lr 4.07 | ms/batch 115.90 | loss 5.14 | ppl 170.65\n",
  1046. "| epoch 5 | 400/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.09 | ppl 163.18\n",
  1047. "| epoch 5 | 600/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.07 | ppl 159.22\n",
  1048. "| epoch 5 | 800/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.08 | ppl 160.60\n",
  1049. "| epoch 5 | 1000/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.08 | ppl 160.49\n",
  1050. "| epoch 5 | 1200/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.07 | ppl 158.86\n",
  1051. "| epoch 5 | 1400/13484 batches | lr 4.07 | ms/batch 115.14 | loss 5.06 | ppl 156.88\n",
  1052. "| epoch 5 | 1600/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.10 | ppl 164.68\n",
  1053. "| epoch 5 | 1800/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.09 | ppl 161.68\n",
  1054. "| epoch 5 | 2000/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.05 | ppl 156.19\n",
  1055. "| epoch 5 | 2200/13484 batches | lr 4.07 | ms/batch 115.16 | loss 5.06 | ppl 157.65\n",
  1056. "| epoch 5 | 2400/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.05 | ppl 156.29\n",
  1057. "| epoch 5 | 2600/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.04 | ppl 155.08\n",
  1058. "| epoch 5 | 2800/13484 batches | lr 4.07 | ms/batch 115.12 | loss 5.08 | ppl 160.79\n",
  1059. "| epoch 5 | 3000/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.06 | ppl 157.93\n",
  1060. "| epoch 5 | 3200/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.03 | ppl 153.59\n",
  1061. "| epoch 5 | 3400/13484 batches | lr 4.07 | ms/batch 115.24 | loss 5.10 | ppl 164.69\n",
  1062. "| epoch 5 | 3600/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.07 | ppl 159.67\n",
  1063. "| epoch 5 | 3800/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.05 | ppl 156.33\n",
  1064. "| epoch 5 | 4000/13484 batches | lr 4.07 | ms/batch 115.30 | loss 5.00 | ppl 148.52\n",
  1065. "| epoch 5 | 4200/13484 batches | lr 4.07 | ms/batch 115.16 | loss 5.03 | ppl 153.04\n",
  1066. "| epoch 5 | 4400/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.04 | ppl 155.12\n",
  1067. "| epoch 5 | 4600/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.09 | ppl 162.86\n",
  1068. "| epoch 5 | 4800/13484 batches | lr 4.07 | ms/batch 115.21 | loss 5.07 | ppl 159.17\n",
  1069. "| epoch 5 | 5000/13484 batches | lr 4.07 | ms/batch 115.27 | loss 5.06 | ppl 157.50\n",
  1070. "| epoch 5 | 5200/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.06 | ppl 157.70\n",
  1071. "| epoch 5 | 5400/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.04 | ppl 154.31\n",
  1072. "| epoch 5 | 5600/13484 batches | lr 4.07 | ms/batch 115.18 | loss 5.05 | ppl 156.47\n",
  1073. "| epoch 5 | 5800/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.06 | ppl 157.43\n",
  1074. "| epoch 5 | 6000/13484 batches | lr 4.07 | ms/batch 115.15 | loss 5.07 | ppl 159.33\n",
  1075. "| epoch 5 | 6200/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.09 | ppl 163.19\n",
  1076. "| epoch 5 | 6400/13484 batches | lr 4.07 | ms/batch 115.19 | loss 5.07 | ppl 159.77\n",
  1077. "| epoch 5 | 6600/13484 batches | lr 4.07 | ms/batch 115.28 | loss 5.09 | ppl 163.17\n",
  1078. "| epoch 5 | 6800/13484 batches | lr 4.07 | ms/batch 115.11 | loss 5.03 | ppl 153.03\n",
  1079. "| epoch 5 | 7000/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.09 | ppl 161.90\n",
  1080. "| epoch 5 | 7200/13484 batches | lr 4.07 | ms/batch 115.21 | loss 5.06 | ppl 156.90\n",
  1081. "| epoch 5 | 7400/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.07 | ppl 159.02\n",
  1082. "| epoch 5 | 7600/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.05 | ppl 156.02\n",
  1083. "| epoch 5 | 7800/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.06 | ppl 157.20\n",
  1084. "| epoch 5 | 8000/13484 batches | lr 4.07 | ms/batch 115.20 | loss 5.04 | ppl 154.56\n",
  1085. "| epoch 5 | 8200/13484 batches | lr 4.07 | ms/batch 115.20 | loss 5.03 | ppl 152.46\n",
  1086. "| epoch 5 | 8400/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.06 | ppl 157.62\n",
  1087. "| epoch 5 | 8600/13484 batches | lr 4.07 | ms/batch 115.28 | loss 5.07 | ppl 158.74\n",
  1088. "| epoch 5 | 8800/13484 batches | lr 4.07 | ms/batch 115.30 | loss 5.04 | ppl 154.53\n",
  1089. "| epoch 5 | 9000/13484 batches | lr 4.07 | ms/batch 115.31 | loss 5.06 | ppl 157.02\n",
  1090. "| epoch 5 | 9200/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.07 | ppl 159.14\n",
  1091. "| epoch 5 | 9400/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.07 | ppl 159.15\n",
  1092. "| epoch 5 | 9600/13484 batches | lr 4.07 | ms/batch 115.24 | loss 5.04 | ppl 153.89\n",
  1093. "| epoch 5 | 9800/13484 batches | lr 4.07 | ms/batch 115.27 | loss 5.02 | ppl 151.96\n",
  1094. "| epoch 5 | 10000/13484 batches | lr 4.07 | ms/batch 115.24 | loss 5.05 | ppl 156.58\n",
  1095. "| epoch 5 | 10200/13484 batches | lr 4.07 | ms/batch 115.30 | loss 5.02 | ppl 152.10\n",
  1096. "| epoch 5 | 10400/13484 batches | lr 4.07 | ms/batch 115.33 | loss 4.98 | ppl 146.11\n",
  1097. "| epoch 5 | 10600/13484 batches | lr 4.07 | ms/batch 115.26 | loss 5.03 | ppl 153.27\n",
  1098. "| epoch 5 | 10800/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.06 | ppl 157.44\n",
  1099. "| epoch 5 | 11000/13484 batches | lr 4.07 | ms/batch 115.33 | loss 5.05 | ppl 156.34\n",
  1100. "| epoch 5 | 11200/13484 batches | lr 4.07 | ms/batch 115.22 | loss 5.04 | ppl 154.36\n",
  1101. "| epoch 5 | 11400/13484 batches | lr 4.07 | ms/batch 115.27 | loss 5.00 | ppl 148.51\n",
  1102. "| epoch 5 | 11600/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.03 | ppl 153.09\n",
  1103. "| epoch 5 | 11800/13484 batches | lr 4.07 | ms/batch 115.26 | loss 5.00 | ppl 148.85\n",
  1104. "| epoch 5 | 12000/13484 batches | lr 4.07 | ms/batch 115.25 | loss 5.06 | ppl 156.93\n",
  1105. "| epoch 5 | 12200/13484 batches | lr 4.07 | ms/batch 115.22 | loss 4.98 | ppl 145.89\n",
  1106. "| epoch 5 | 12400/13484 batches | lr 4.07 | ms/batch 115.20 | loss 5.00 | ppl 148.86\n",
  1107. "| epoch 5 | 12600/13484 batches | lr 4.07 | ms/batch 115.33 | loss 5.02 | ppl 151.32\n",
  1108. "| epoch 5 | 12800/13484 batches | lr 4.07 | ms/batch 115.23 | loss 5.04 | ppl 154.42\n",
  1109. "| epoch 5 | 13000/13484 batches | lr 4.07 | ms/batch 115.28 | loss 5.03 | ppl 152.95\n",
  1110. "| epoch 5 | 13200/13484 batches | lr 4.07 | ms/batch 115.17 | loss 5.03 | ppl 153.49\n",
  1111. "| epoch 5 | 13400/13484 batches | lr 4.07 | ms/batch 115.35 | loss 5.05 | ppl 155.92\n",
  1112. "-----------------------------------------------------------------------------------------\n",
  1113. "| end of epoch 5 | time: 1625.93s | valid loss 5.10 | valid ppl 164.06\n",
  1114. "-----------------------------------------------------------------------------------------\n",
  1115. "| epoch 6 | 200/13484 batches | lr 3.87 | ms/batch 115.82 | loss 5.07 | ppl 159.79\n",
  1116. "| epoch 6 | 400/13484 batches | lr 3.87 | ms/batch 115.22 | loss 5.04 | ppl 153.75\n",
  1117. "| epoch 6 | 600/13484 batches | lr 3.87 | ms/batch 115.14 | loss 5.01 | ppl 150.20\n",
  1118. "| epoch 6 | 800/13484 batches | lr 3.87 | ms/batch 115.25 | loss 5.01 | ppl 150.24\n",
  1119. "| epoch 6 | 1000/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.01 | ppl 149.42\n",
  1120. "| epoch 6 | 1200/13484 batches | lr 3.87 | ms/batch 115.09 | loss 5.01 | ppl 150.28\n",
  1121. "| epoch 6 | 1400/13484 batches | lr 3.87 | ms/batch 115.18 | loss 5.00 | ppl 148.53\n",
  1122. "| epoch 6 | 1600/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.05 | ppl 156.45\n",
  1123. "| epoch 6 | 1800/13484 batches | lr 3.87 | ms/batch 115.17 | loss 5.02 | ppl 151.97\n",
  1124. "| epoch 6 | 2000/13484 batches | lr 3.87 | ms/batch 115.14 | loss 5.00 | ppl 147.68\n",
  1125. "| epoch 6 | 2200/13484 batches | lr 3.87 | ms/batch 115.22 | loss 5.00 | ppl 148.99\n",
  1126. "| epoch 6 | 2400/13484 batches | lr 3.87 | ms/batch 115.19 | loss 5.00 | ppl 147.82\n",
  1127. "| epoch 6 | 2600/13484 batches | lr 3.87 | ms/batch 115.19 | loss 4.98 | ppl 145.20\n",
  1128. "| epoch 6 | 2800/13484 batches | lr 3.87 | ms/batch 115.20 | loss 5.02 | ppl 152.00\n",
  1129. "| epoch 6 | 3000/13484 batches | lr 3.87 | ms/batch 115.20 | loss 5.01 | ppl 149.24\n",
  1130. "| epoch 6 | 3200/13484 batches | lr 3.87 | ms/batch 115.23 | loss 4.98 | ppl 145.09\n",
  1131. "| epoch 6 | 3400/13484 batches | lr 3.87 | ms/batch 115.35 | loss 5.03 | ppl 153.68\n",
  1132. "| epoch 6 | 3600/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.01 | ppl 149.34\n"
  1133. ]
  1134. },
  1135. {
  1136. "name": "stdout",
  1137. "output_type": "stream",
  1138. "text": [
  1139. "| epoch 6 | 3800/13484 batches | lr 3.87 | ms/batch 115.20 | loss 5.00 | ppl 148.07\n",
  1140. "| epoch 6 | 4000/13484 batches | lr 3.87 | ms/batch 115.32 | loss 4.94 | ppl 140.04\n",
  1141. "| epoch 6 | 4200/13484 batches | lr 3.87 | ms/batch 115.21 | loss 4.97 | ppl 144.64\n",
  1142. "| epoch 6 | 4400/13484 batches | lr 3.87 | ms/batch 115.20 | loss 4.99 | ppl 146.48\n",
  1143. "| epoch 6 | 4600/13484 batches | lr 3.87 | ms/batch 115.18 | loss 5.03 | ppl 153.49\n",
  1144. "| epoch 6 | 4800/13484 batches | lr 3.87 | ms/batch 115.30 | loss 5.01 | ppl 150.20\n",
  1145. "| epoch 6 | 5000/13484 batches | lr 3.87 | ms/batch 115.24 | loss 5.00 | ppl 148.23\n",
  1146. "| epoch 6 | 5200/13484 batches | lr 3.87 | ms/batch 115.22 | loss 5.00 | ppl 148.51\n",
  1147. "| epoch 6 | 5400/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.98 | ppl 145.45\n",
  1148. "| epoch 6 | 5600/13484 batches | lr 3.87 | ms/batch 115.25 | loss 4.99 | ppl 146.84\n",
  1149. "| epoch 6 | 5800/13484 batches | lr 3.87 | ms/batch 115.29 | loss 4.99 | ppl 147.24\n",
  1150. "| epoch 6 | 6000/13484 batches | lr 3.87 | ms/batch 115.19 | loss 5.01 | ppl 150.09\n",
  1151. "| epoch 6 | 6200/13484 batches | lr 3.87 | ms/batch 115.21 | loss 5.03 | ppl 152.86\n",
  1152. "| epoch 6 | 6400/13484 batches | lr 3.87 | ms/batch 115.17 | loss 5.02 | ppl 150.83\n",
  1153. "| epoch 6 | 6600/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.05 | ppl 155.25\n",
  1154. "| epoch 6 | 6800/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.00 | ppl 148.10\n",
  1155. "| epoch 6 | 7000/13484 batches | lr 3.87 | ms/batch 115.35 | loss 5.03 | ppl 152.52\n",
  1156. "| epoch 6 | 7200/13484 batches | lr 3.87 | ms/batch 115.25 | loss 5.00 | ppl 148.62\n",
  1157. "| epoch 6 | 7400/13484 batches | lr 3.87 | ms/batch 115.30 | loss 5.00 | ppl 148.56\n",
  1158. "| epoch 6 | 7600/13484 batches | lr 3.87 | ms/batch 115.25 | loss 4.99 | ppl 147.28\n",
  1159. "| epoch 6 | 7800/13484 batches | lr 3.87 | ms/batch 115.24 | loss 5.00 | ppl 147.93\n",
  1160. "| epoch 6 | 8000/13484 batches | lr 3.87 | ms/batch 115.24 | loss 4.98 | ppl 145.76\n",
  1161. "| epoch 6 | 8200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.97 | ppl 143.39\n",
  1162. "| epoch 6 | 8400/13484 batches | lr 3.87 | ms/batch 115.24 | loss 4.99 | ppl 147.14\n",
  1163. "| epoch 6 | 8600/13484 batches | lr 3.87 | ms/batch 115.27 | loss 5.00 | ppl 148.00\n",
  1164. "| epoch 6 | 8800/13484 batches | lr 3.87 | ms/batch 115.35 | loss 4.98 | ppl 145.27\n",
  1165. "| epoch 6 | 9000/13484 batches | lr 3.87 | ms/batch 115.27 | loss 5.01 | ppl 150.06\n",
  1166. "| epoch 6 | 9200/13484 batches | lr 3.87 | ms/batch 115.21 | loss 5.01 | ppl 150.09\n",
  1167. "| epoch 6 | 9400/13484 batches | lr 3.87 | ms/batch 115.28 | loss 5.01 | ppl 150.08\n",
  1168. "| epoch 6 | 9600/13484 batches | lr 3.87 | ms/batch 115.16 | loss 4.99 | ppl 147.55\n",
  1169. "| epoch 6 | 9800/13484 batches | lr 3.87 | ms/batch 115.27 | loss 4.97 | ppl 143.67\n",
  1170. "| epoch 6 | 10000/13484 batches | lr 3.87 | ms/batch 115.20 | loss 4.99 | ppl 147.66\n",
  1171. "| epoch 6 | 10200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.95 | ppl 141.61\n",
  1172. "| epoch 6 | 10400/13484 batches | lr 3.87 | ms/batch 115.20 | loss 4.93 | ppl 138.76\n",
  1173. "| epoch 6 | 10600/13484 batches | lr 3.87 | ms/batch 115.28 | loss 4.97 | ppl 144.59\n",
  1174. "| epoch 6 | 10800/13484 batches | lr 3.87 | ms/batch 115.23 | loss 5.01 | ppl 149.16\n",
  1175. "| epoch 6 | 11000/13484 batches | lr 3.87 | ms/batch 115.29 | loss 5.00 | ppl 148.35\n",
  1176. "| epoch 6 | 11200/13484 batches | lr 3.87 | ms/batch 115.29 | loss 5.01 | ppl 149.31\n",
  1177. "| epoch 6 | 11400/13484 batches | lr 3.87 | ms/batch 115.29 | loss 4.95 | ppl 141.26\n",
  1178. "| epoch 6 | 11600/13484 batches | lr 3.87 | ms/batch 115.34 | loss 4.98 | ppl 145.07\n",
  1179. "| epoch 6 | 11800/13484 batches | lr 3.87 | ms/batch 115.28 | loss 4.94 | ppl 140.00\n",
  1180. "| epoch 6 | 12000/13484 batches | lr 3.87 | ms/batch 115.19 | loss 5.00 | ppl 147.85\n",
  1181. "| epoch 6 | 12200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.93 | ppl 137.74\n",
  1182. "| epoch 6 | 12400/13484 batches | lr 3.87 | ms/batch 115.26 | loss 4.95 | ppl 140.89\n",
  1183. "| epoch 6 | 12600/13484 batches | lr 3.87 | ms/batch 115.38 | loss 4.97 | ppl 143.33\n",
  1184. "| epoch 6 | 12800/13484 batches | lr 3.87 | ms/batch 115.29 | loss 4.98 | ppl 145.29\n",
  1185. "| epoch 6 | 13000/13484 batches | lr 3.87 | ms/batch 115.37 | loss 4.97 | ppl 144.45\n",
  1186. "| epoch 6 | 13200/13484 batches | lr 3.87 | ms/batch 115.31 | loss 4.98 | ppl 146.13\n",
  1187. "| epoch 6 | 13400/13484 batches | lr 3.87 | ms/batch 115.33 | loss 5.00 | ppl 148.36\n",
  1188. "-----------------------------------------------------------------------------------------\n",
  1189. "| end of epoch 6 | time: 1626.41s | valid loss 5.09 | valid ppl 162.11\n",
  1190. "-----------------------------------------------------------------------------------------\n",
  1191. "| epoch 7 | 200/13484 batches | lr 3.68 | ms/batch 115.82 | loss 5.02 | ppl 151.17\n",
  1192. "| epoch 7 | 400/13484 batches | lr 3.68 | ms/batch 115.20 | loss 4.98 | ppl 144.87\n",
  1193. "| epoch 7 | 600/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.96 | ppl 142.39\n",
  1194. "| epoch 7 | 800/13484 batches | lr 3.68 | ms/batch 115.16 | loss 4.96 | ppl 142.45\n",
  1195. "| epoch 7 | 1000/13484 batches | lr 3.68 | ms/batch 115.08 | loss 4.96 | ppl 142.21\n",
  1196. "| epoch 7 | 1200/13484 batches | lr 3.68 | ms/batch 115.19 | loss 4.96 | ppl 142.21\n",
  1197. "| epoch 7 | 1400/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.94 | ppl 139.97\n",
  1198. "| epoch 7 | 1600/13484 batches | lr 3.68 | ms/batch 115.13 | loss 4.99 | ppl 146.87\n",
  1199. "| epoch 7 | 1800/13484 batches | lr 3.68 | ms/batch 115.11 | loss 4.97 | ppl 144.27\n",
  1200. "| epoch 7 | 2000/13484 batches | lr 3.68 | ms/batch 115.14 | loss 4.94 | ppl 139.63\n",
  1201. "| epoch 7 | 2200/13484 batches | lr 3.68 | ms/batch 115.13 | loss 4.94 | ppl 140.28\n",
  1202. "| epoch 7 | 2400/13484 batches | lr 3.68 | ms/batch 115.14 | loss 4.94 | ppl 140.42\n",
  1203. "| epoch 7 | 2600/13484 batches | lr 3.68 | ms/batch 115.20 | loss 4.93 | ppl 138.37\n",
  1204. "| epoch 7 | 2800/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.97 | ppl 144.51\n",
  1205. "| epoch 7 | 3000/13484 batches | lr 3.68 | ms/batch 115.22 | loss 4.95 | ppl 141.43\n",
  1206. "| epoch 7 | 3200/13484 batches | lr 3.68 | ms/batch 115.17 | loss 4.92 | ppl 137.29\n",
  1207. "| epoch 7 | 3400/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.98 | ppl 145.62\n",
  1208. "| epoch 7 | 3600/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.95 | ppl 141.60\n",
  1209. "| epoch 7 | 3800/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.94 | ppl 139.88\n",
  1210. "| epoch 7 | 4000/13484 batches | lr 3.68 | ms/batch 115.19 | loss 4.89 | ppl 133.49\n",
  1211. "| epoch 7 | 4200/13484 batches | lr 3.68 | ms/batch 115.17 | loss 4.93 | ppl 138.21\n",
  1212. "| epoch 7 | 4400/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.94 | ppl 139.14\n",
  1213. "| epoch 7 | 4600/13484 batches | lr 3.68 | ms/batch 115.28 | loss 4.98 | ppl 145.67\n",
  1214. "| epoch 7 | 4800/13484 batches | lr 3.68 | ms/batch 115.25 | loss 4.96 | ppl 143.05\n",
  1215. "| epoch 7 | 5000/13484 batches | lr 3.68 | ms/batch 115.20 | loss 4.95 | ppl 141.27\n",
  1216. "| epoch 7 | 5200/13484 batches | lr 3.68 | ms/batch 115.18 | loss 4.95 | ppl 140.78\n",
  1217. "| epoch 7 | 5400/13484 batches | lr 3.68 | ms/batch 115.35 | loss 4.93 | ppl 137.98\n",
  1218. "| epoch 7 | 5600/13484 batches | lr 3.68 | ms/batch 115.29 | loss 4.94 | ppl 139.66\n",
  1219. "| epoch 7 | 5800/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.94 | ppl 139.99\n",
  1220. "| epoch 7 | 6000/13484 batches | lr 3.68 | ms/batch 115.26 | loss 4.96 | ppl 142.34\n",
  1221. "| epoch 7 | 6200/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.99 | ppl 146.32\n",
  1222. "| epoch 7 | 6400/13484 batches | lr 3.68 | ms/batch 115.18 | loss 4.96 | ppl 142.33\n",
  1223. "| epoch 7 | 6600/13484 batches | lr 3.68 | ms/batch 115.22 | loss 4.99 | ppl 146.69\n",
  1224. "| epoch 7 | 6800/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.93 | ppl 137.90\n",
  1225. "| epoch 7 | 7000/13484 batches | lr 3.68 | ms/batch 115.18 | loss 4.98 | ppl 145.72\n",
  1226. "| epoch 7 | 7200/13484 batches | lr 3.68 | ms/batch 115.25 | loss 4.94 | ppl 140.06\n",
  1227. "| epoch 7 | 7400/13484 batches | lr 3.68 | ms/batch 115.14 | loss 4.94 | ppl 140.43\n",
  1228. "| epoch 7 | 7600/13484 batches | lr 3.68 | ms/batch 115.28 | loss 4.95 | ppl 140.71\n",
  1229. "| epoch 7 | 7800/13484 batches | lr 3.68 | ms/batch 115.26 | loss 4.94 | ppl 140.23\n",
  1230. "| epoch 7 | 8000/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.93 | ppl 138.76\n"
  1231. ]
  1232. },
  1233. {
  1234. "name": "stdout",
  1235. "output_type": "stream",
  1236. "text": [
  1237. "| epoch 7 | 8200/13484 batches | lr 3.68 | ms/batch 115.26 | loss 4.92 | ppl 136.77\n",
  1238. "| epoch 7 | 8400/13484 batches | lr 3.68 | ms/batch 115.33 | loss 4.94 | ppl 139.78\n",
  1239. "| epoch 7 | 8600/13484 batches | lr 3.68 | ms/batch 115.31 | loss 4.95 | ppl 141.01\n",
  1240. "| epoch 7 | 8800/13484 batches | lr 3.68 | ms/batch 115.18 | loss 4.93 | ppl 138.80\n",
  1241. "| epoch 7 | 9000/13484 batches | lr 3.68 | ms/batch 115.19 | loss 4.95 | ppl 141.73\n",
  1242. "| epoch 7 | 9200/13484 batches | lr 3.68 | ms/batch 115.23 | loss 4.97 | ppl 144.05\n",
  1243. "| epoch 7 | 9400/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.97 | ppl 144.66\n",
  1244. "| epoch 7 | 9600/13484 batches | lr 3.68 | ms/batch 115.26 | loss 4.93 | ppl 138.12\n",
  1245. "| epoch 7 | 9800/13484 batches | lr 3.68 | ms/batch 115.27 | loss 4.91 | ppl 135.39\n",
  1246. "| epoch 7 | 10000/13484 batches | lr 3.68 | ms/batch 115.19 | loss 4.94 | ppl 140.12\n",
  1247. "| epoch 7 | 10200/13484 batches | lr 3.68 | ms/batch 115.27 | loss 4.90 | ppl 134.73\n",
  1248. "| epoch 7 | 10400/13484 batches | lr 3.68 | ms/batch 115.29 | loss 4.87 | ppl 130.45\n",
  1249. "| epoch 7 | 10600/13484 batches | lr 3.68 | ms/batch 115.35 | loss 4.92 | ppl 137.36\n",
  1250. "| epoch 7 | 10800/13484 batches | lr 3.68 | ms/batch 115.29 | loss 4.94 | ppl 140.35\n",
  1251. "| epoch 7 | 11000/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.95 | ppl 141.22\n",
  1252. "| epoch 7 | 11200/13484 batches | lr 3.68 | ms/batch 115.28 | loss 4.94 | ppl 139.33\n",
  1253. "| epoch 7 | 11400/13484 batches | lr 3.68 | ms/batch 115.28 | loss 4.89 | ppl 133.07\n",
  1254. "| epoch 7 | 11600/13484 batches | lr 3.68 | ms/batch 115.21 | loss 4.93 | ppl 137.82\n",
  1255. "| epoch 7 | 11800/13484 batches | lr 3.68 | ms/batch 115.33 | loss 4.89 | ppl 132.51\n",
  1256. "| epoch 7 | 12000/13484 batches | lr 3.68 | ms/batch 115.32 | loss 4.94 | ppl 139.89\n",
  1257. "| epoch 7 | 12200/13484 batches | lr 3.68 | ms/batch 115.25 | loss 4.88 | ppl 131.43\n",
  1258. "| epoch 7 | 12400/13484 batches | lr 3.68 | ms/batch 115.32 | loss 4.89 | ppl 133.23\n",
  1259. "| epoch 7 | 12600/13484 batches | lr 3.68 | ms/batch 115.30 | loss 4.92 | ppl 136.69\n",
  1260. "| epoch 7 | 12800/13484 batches | lr 3.68 | ms/batch 115.27 | loss 4.94 | ppl 139.23\n",
  1261. "| epoch 7 | 13000/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.92 | ppl 136.46\n",
  1262. "| epoch 7 | 13200/13484 batches | lr 3.68 | ms/batch 115.31 | loss 4.92 | ppl 137.53\n",
  1263. "| epoch 7 | 13400/13484 batches | lr 3.68 | ms/batch 115.24 | loss 4.94 | ppl 140.23\n",
  1264. "-----------------------------------------------------------------------------------------\n",
  1265. "| end of epoch 7 | time: 1626.06s | valid loss 5.05 | valid ppl 155.94\n",
  1266. "-----------------------------------------------------------------------------------------\n",
  1267. "| epoch 8 | 200/13484 batches | lr 3.49 | ms/batch 115.85 | loss 4.97 | ppl 143.91\n",
  1268. "| epoch 8 | 400/13484 batches | lr 3.49 | ms/batch 115.20 | loss 4.93 | ppl 138.69\n",
  1269. "| epoch 8 | 600/13484 batches | lr 3.49 | ms/batch 115.25 | loss 4.92 | ppl 137.31\n",
  1270. "| epoch 8 | 800/13484 batches | lr 3.49 | ms/batch 115.21 | loss 4.91 | ppl 135.14\n",
  1271. "| epoch 8 | 1000/13484 batches | lr 3.49 | ms/batch 115.09 | loss 4.91 | ppl 136.03\n",
  1272. "| epoch 8 | 1200/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.91 | ppl 135.25\n",
  1273. "| epoch 8 | 1400/13484 batches | lr 3.49 | ms/batch 115.16 | loss 4.89 | ppl 132.45\n",
  1274. "| epoch 8 | 1600/13484 batches | lr 3.49 | ms/batch 115.19 | loss 4.94 | ppl 139.52\n",
  1275. "| epoch 8 | 1800/13484 batches | lr 3.49 | ms/batch 115.13 | loss 4.92 | ppl 136.90\n",
  1276. "| epoch 8 | 2000/13484 batches | lr 3.49 | ms/batch 115.23 | loss 4.89 | ppl 132.80\n",
  1277. "| epoch 8 | 2200/13484 batches | lr 3.49 | ms/batch 115.12 | loss 4.89 | ppl 132.74\n",
  1278. "| epoch 8 | 2400/13484 batches | lr 3.49 | ms/batch 115.25 | loss 4.90 | ppl 133.92\n",
  1279. "| epoch 8 | 2600/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.88 | ppl 131.38\n",
  1280. "| epoch 8 | 2800/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.92 | ppl 137.14\n",
  1281. "| epoch 8 | 3000/13484 batches | lr 3.49 | ms/batch 115.18 | loss 4.90 | ppl 134.47\n",
  1282. "| epoch 8 | 3200/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.87 | ppl 130.24\n",
  1283. "| epoch 8 | 3400/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.93 | ppl 139.00\n",
  1284. "| epoch 8 | 3600/13484 batches | lr 3.49 | ms/batch 115.20 | loss 4.91 | ppl 135.20\n",
  1285. "| epoch 8 | 3800/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.90 | ppl 133.96\n",
  1286. "| epoch 8 | 4000/13484 batches | lr 3.49 | ms/batch 115.19 | loss 4.84 | ppl 127.05\n",
  1287. "| epoch 8 | 4200/13484 batches | lr 3.49 | ms/batch 115.30 | loss 4.87 | ppl 130.76\n",
  1288. "| epoch 8 | 4400/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.88 | ppl 132.28\n",
  1289. "| epoch 8 | 4600/13484 batches | lr 3.49 | ms/batch 115.36 | loss 4.93 | ppl 138.46\n",
  1290. "| epoch 8 | 4800/13484 batches | lr 3.49 | ms/batch 115.18 | loss 4.91 | ppl 135.37\n",
  1291. "| epoch 8 | 5000/13484 batches | lr 3.49 | ms/batch 115.23 | loss 4.90 | ppl 134.12\n",
  1292. "| epoch 8 | 5200/13484 batches | lr 3.49 | ms/batch 115.21 | loss 4.90 | ppl 134.65\n",
  1293. "| epoch 8 | 5400/13484 batches | lr 3.49 | ms/batch 115.35 | loss 4.87 | ppl 130.93\n",
  1294. "| epoch 8 | 5600/13484 batches | lr 3.49 | ms/batch 115.35 | loss 4.89 | ppl 133.28\n",
  1295. "| epoch 8 | 5800/13484 batches | lr 3.49 | ms/batch 115.23 | loss 4.89 | ppl 132.54\n",
  1296. "| epoch 8 | 6000/13484 batches | lr 3.49 | ms/batch 115.22 | loss 4.91 | ppl 135.15\n",
  1297. "| epoch 8 | 6200/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.94 | ppl 139.25\n",
  1298. "| epoch 8 | 6400/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.91 | ppl 135.37\n",
  1299. "| epoch 8 | 6600/13484 batches | lr 3.49 | ms/batch 115.17 | loss 4.94 | ppl 139.28\n",
  1300. "| epoch 8 | 6800/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.88 | ppl 132.05\n",
  1301. "| epoch 8 | 7000/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.92 | ppl 137.41\n",
  1302. "| epoch 8 | 7200/13484 batches | lr 3.49 | ms/batch 115.34 | loss 4.90 | ppl 133.68\n",
  1303. "| epoch 8 | 7400/13484 batches | lr 3.49 | ms/batch 115.27 | loss 4.89 | ppl 133.58\n",
  1304. "| epoch 8 | 7600/13484 batches | lr 3.49 | ms/batch 115.26 | loss 4.90 | ppl 133.64\n",
  1305. "| epoch 8 | 7800/13484 batches | lr 3.49 | ms/batch 115.33 | loss 4.89 | ppl 133.55\n",
  1306. "| epoch 8 | 8000/13484 batches | lr 3.49 | ms/batch 115.17 | loss 4.88 | ppl 132.23\n",
  1307. "| epoch 8 | 8200/13484 batches | lr 3.49 | ms/batch 115.25 | loss 4.87 | ppl 129.93\n",
  1308. "| epoch 8 | 8400/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.89 | ppl 133.16\n",
  1309. "| epoch 8 | 8600/13484 batches | lr 3.49 | ms/batch 115.30 | loss 4.89 | ppl 133.49\n",
  1310. "| epoch 8 | 8800/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.88 | ppl 131.42\n",
  1311. "| epoch 8 | 9000/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.89 | ppl 133.59\n",
  1312. "| epoch 8 | 9200/13484 batches | lr 3.49 | ms/batch 115.28 | loss 4.91 | ppl 136.20\n",
  1313. "| epoch 8 | 9400/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.91 | ppl 135.54\n",
  1314. "| epoch 8 | 9600/13484 batches | lr 3.49 | ms/batch 115.32 | loss 4.88 | ppl 131.19\n",
  1315. "| epoch 8 | 9800/13484 batches | lr 3.49 | ms/batch 115.34 | loss 4.86 | ppl 128.72\n",
  1316. "| epoch 8 | 10000/13484 batches | lr 3.49 | ms/batch 115.32 | loss 4.89 | ppl 132.80\n",
  1317. "| epoch 8 | 10200/13484 batches | lr 3.49 | ms/batch 115.33 | loss 4.85 | ppl 128.25\n",
  1318. "| epoch 8 | 10400/13484 batches | lr 3.49 | ms/batch 115.35 | loss 4.83 | ppl 124.93\n",
  1319. "| epoch 8 | 10600/13484 batches | lr 3.49 | ms/batch 115.30 | loss 4.87 | ppl 130.59\n",
  1320. "| epoch 8 | 10800/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.90 | ppl 133.78\n",
  1321. "| epoch 8 | 11000/13484 batches | lr 3.49 | ms/batch 115.30 | loss 4.90 | ppl 133.75\n",
  1322. "| epoch 8 | 11200/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.89 | ppl 133.33\n",
  1323. "| epoch 8 | 11400/13484 batches | lr 3.49 | ms/batch 115.36 | loss 4.84 | ppl 126.25\n",
  1324. "| epoch 8 | 11600/13484 batches | lr 3.49 | ms/batch 115.38 | loss 4.88 | ppl 131.70\n",
  1325. "| epoch 8 | 11800/13484 batches | lr 3.49 | ms/batch 115.36 | loss 4.84 | ppl 127.09\n",
  1326. "| epoch 8 | 12000/13484 batches | lr 3.49 | ms/batch 115.38 | loss 4.89 | ppl 133.44\n",
  1327. "| epoch 8 | 12200/13484 batches | lr 3.49 | ms/batch 115.38 | loss 4.83 | ppl 124.78\n",
  1328. "| epoch 8 | 12400/13484 batches | lr 3.49 | ms/batch 115.38 | loss 4.84 | ppl 125.91\n"
  1329. ]
  1330. },
  1331. {
  1332. "name": "stdout",
  1333. "output_type": "stream",
  1334. "text": [
  1335. "| epoch 8 | 12600/13484 batches | lr 3.49 | ms/batch 115.31 | loss 4.86 | ppl 128.83\n",
  1336. "| epoch 8 | 12800/13484 batches | lr 3.49 | ms/batch 115.24 | loss 4.88 | ppl 131.60\n",
  1337. "| epoch 8 | 13000/13484 batches | lr 3.49 | ms/batch 115.33 | loss 4.87 | ppl 130.10\n",
  1338. "| epoch 8 | 13200/13484 batches | lr 3.49 | ms/batch 115.29 | loss 4.88 | ppl 131.87\n",
  1339. "| epoch 8 | 13400/13484 batches | lr 3.49 | ms/batch 115.40 | loss 4.90 | ppl 134.29\n",
  1340. "-----------------------------------------------------------------------------------------\n",
  1341. "| end of epoch 8 | time: 1626.66s | valid loss 5.00 | valid ppl 148.39\n",
  1342. "-----------------------------------------------------------------------------------------\n",
  1343. "| epoch 9 | 200/13484 batches | lr 3.32 | ms/batch 115.97 | loss 4.92 | ppl 136.72\n",
  1344. "| epoch 9 | 400/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.88 | ppl 131.62\n",
  1345. "| epoch 9 | 600/13484 batches | lr 3.32 | ms/batch 115.30 | loss 4.85 | ppl 128.22\n",
  1346. "| epoch 9 | 800/13484 batches | lr 3.32 | ms/batch 115.29 | loss 4.86 | ppl 128.84\n",
  1347. "| epoch 9 | 1000/13484 batches | lr 3.32 | ms/batch 115.29 | loss 4.86 | ppl 129.65\n",
  1348. "| epoch 9 | 1200/13484 batches | lr 3.32 | ms/batch 115.21 | loss 4.86 | ppl 128.93\n",
  1349. "| epoch 9 | 1400/13484 batches | lr 3.32 | ms/batch 115.28 | loss 4.85 | ppl 127.80\n",
  1350. "| epoch 9 | 1600/13484 batches | lr 3.32 | ms/batch 115.36 | loss 4.89 | ppl 132.74\n",
  1351. "| epoch 9 | 1800/13484 batches | lr 3.32 | ms/batch 115.27 | loss 4.88 | ppl 131.14\n",
  1352. "| epoch 9 | 2000/13484 batches | lr 3.32 | ms/batch 115.32 | loss 4.84 | ppl 126.60\n",
  1353. "| epoch 9 | 2200/13484 batches | lr 3.32 | ms/batch 115.33 | loss 4.84 | ppl 126.74\n",
  1354. "| epoch 9 | 2400/13484 batches | lr 3.32 | ms/batch 115.32 | loss 4.84 | ppl 127.02\n",
  1355. "| epoch 9 | 2600/13484 batches | lr 3.32 | ms/batch 115.31 | loss 4.84 | ppl 126.21\n",
  1356. "| epoch 9 | 2800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.87 | ppl 130.53\n",
  1357. "| epoch 9 | 3000/13484 batches | lr 3.32 | ms/batch 115.31 | loss 4.85 | ppl 127.68\n",
  1358. "| epoch 9 | 3200/13484 batches | lr 3.32 | ms/batch 115.30 | loss 4.83 | ppl 125.33\n",
  1359. "| epoch 9 | 3400/13484 batches | lr 3.32 | ms/batch 115.26 | loss 4.89 | ppl 133.40\n",
  1360. "| epoch 9 | 3600/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.86 | ppl 129.20\n",
  1361. "| epoch 9 | 3800/13484 batches | lr 3.32 | ms/batch 115.39 | loss 4.85 | ppl 127.67\n",
  1362. "| epoch 9 | 4000/13484 batches | lr 3.32 | ms/batch 115.39 | loss 4.80 | ppl 121.75\n",
  1363. "| epoch 9 | 4200/13484 batches | lr 3.32 | ms/batch 115.30 | loss 4.83 | ppl 125.31\n",
  1364. "| epoch 9 | 4400/13484 batches | lr 3.32 | ms/batch 115.44 | loss 4.84 | ppl 126.39\n",
  1365. "| epoch 9 | 4600/13484 batches | lr 3.32 | ms/batch 115.29 | loss 4.88 | ppl 131.31\n",
  1366. "| epoch 9 | 4800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.86 | ppl 129.65\n",
  1367. "| epoch 9 | 5000/13484 batches | lr 3.32 | ms/batch 115.37 | loss 4.85 | ppl 128.14\n",
  1368. "| epoch 9 | 5200/13484 batches | lr 3.32 | ms/batch 115.37 | loss 4.85 | ppl 128.35\n",
  1369. "| epoch 9 | 5400/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.83 | ppl 124.69\n",
  1370. "| epoch 9 | 5600/13484 batches | lr 3.32 | ms/batch 115.47 | loss 4.85 | ppl 127.26\n",
  1371. "| epoch 9 | 5800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.85 | ppl 127.16\n",
  1372. "| epoch 9 | 6000/13484 batches | lr 3.32 | ms/batch 115.45 | loss 4.86 | ppl 128.86\n",
  1373. "| epoch 9 | 6200/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.88 | ppl 131.85\n",
  1374. "| epoch 9 | 6400/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.86 | ppl 129.32\n",
  1375. "| epoch 9 | 6600/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.89 | ppl 132.53\n",
  1376. "| epoch 9 | 6800/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.83 | ppl 125.33\n",
  1377. "| epoch 9 | 7000/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.88 | ppl 131.81\n",
  1378. "| epoch 9 | 7200/13484 batches | lr 3.32 | ms/batch 115.35 | loss 4.85 | ppl 127.28\n",
  1379. "| epoch 9 | 7400/13484 batches | lr 3.32 | ms/batch 115.42 | loss 4.85 | ppl 127.94\n",
  1380. "| epoch 9 | 7600/13484 batches | lr 3.32 | ms/batch 115.42 | loss 4.85 | ppl 127.51\n",
  1381. "| epoch 9 | 7800/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.85 | ppl 127.59\n",
  1382. "| epoch 9 | 8000/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.84 | ppl 126.24\n",
  1383. "| epoch 9 | 8200/13484 batches | lr 3.32 | ms/batch 115.46 | loss 4.82 | ppl 124.14\n",
  1384. "| epoch 9 | 8400/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.85 | ppl 127.32\n",
  1385. "| epoch 9 | 8600/13484 batches | lr 3.32 | ms/batch 115.37 | loss 4.86 | ppl 128.81\n",
  1386. "| epoch 9 | 8800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.83 | ppl 125.56\n",
  1387. "| epoch 9 | 9000/13484 batches | lr 3.32 | ms/batch 115.35 | loss 4.85 | ppl 128.24\n",
  1388. "| epoch 9 | 9200/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.87 | ppl 130.12\n",
  1389. "| epoch 9 | 9400/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.86 | ppl 129.31\n",
  1390. "| epoch 9 | 9600/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.84 | ppl 126.04\n",
  1391. "| epoch 9 | 9800/13484 batches | lr 3.32 | ms/batch 115.47 | loss 4.81 | ppl 122.88\n",
  1392. "| epoch 9 | 10000/13484 batches | lr 3.32 | ms/batch 115.43 | loss 4.84 | ppl 126.54\n",
  1393. "| epoch 9 | 10200/13484 batches | lr 3.32 | ms/batch 115.43 | loss 4.80 | ppl 121.48\n",
  1394. "| epoch 9 | 10400/13484 batches | lr 3.32 | ms/batch 115.35 | loss 4.78 | ppl 118.65\n",
  1395. "| epoch 9 | 10600/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.83 | ppl 124.64\n",
  1396. "| epoch 9 | 10800/13484 batches | lr 3.32 | ms/batch 115.47 | loss 4.85 | ppl 127.13\n",
  1397. "| epoch 9 | 11000/13484 batches | lr 3.32 | ms/batch 115.47 | loss 4.85 | ppl 127.77\n",
  1398. "| epoch 9 | 11200/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.84 | ppl 126.57\n",
  1399. "| epoch 9 | 11400/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.79 | ppl 120.87\n",
  1400. "| epoch 9 | 11600/13484 batches | lr 3.32 | ms/batch 115.33 | loss 4.83 | ppl 125.52\n",
  1401. "| epoch 9 | 11800/13484 batches | lr 3.32 | ms/batch 115.34 | loss 4.80 | ppl 120.94\n",
  1402. "| epoch 9 | 12000/13484 batches | lr 3.32 | ms/batch 115.43 | loss 4.85 | ppl 127.12\n",
  1403. "| epoch 9 | 12200/13484 batches | lr 3.32 | ms/batch 115.41 | loss 4.78 | ppl 119.42\n",
  1404. "| epoch 9 | 12400/13484 batches | lr 3.32 | ms/batch 115.45 | loss 4.79 | ppl 120.40\n",
  1405. "| epoch 9 | 12600/13484 batches | lr 3.32 | ms/batch 115.44 | loss 4.82 | ppl 124.30\n",
  1406. "| epoch 9 | 12800/13484 batches | lr 3.32 | ms/batch 115.40 | loss 4.84 | ppl 126.27\n",
  1407. "| epoch 9 | 13000/13484 batches | lr 3.32 | ms/batch 115.38 | loss 4.83 | ppl 125.13\n",
  1408. "| epoch 9 | 13200/13484 batches | lr 3.32 | ms/batch 115.39 | loss 4.83 | ppl 125.83\n",
  1409. "| epoch 9 | 13400/13484 batches | lr 3.32 | ms/batch 115.42 | loss 4.85 | ppl 128.06\n",
  1410. "-----------------------------------------------------------------------------------------\n",
  1411. "| end of epoch 9 | time: 1628.05s | valid loss 5.02 | valid ppl 150.80\n",
  1412. "-----------------------------------------------------------------------------------------\n",
  1413. "| epoch 10 | 200/13484 batches | lr 3.15 | ms/batch 116.03 | loss 4.88 | ppl 131.35\n",
  1414. "| epoch 10 | 400/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.84 | ppl 126.70\n",
  1415. "| epoch 10 | 600/13484 batches | lr 3.15 | ms/batch 115.47 | loss 4.82 | ppl 124.18\n",
  1416. "| epoch 10 | 800/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.82 | ppl 123.47\n",
  1417. "| epoch 10 | 1000/13484 batches | lr 3.15 | ms/batch 115.31 | loss 4.82 | ppl 124.52\n",
  1418. "| epoch 10 | 1200/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.83 | ppl 124.69\n",
  1419. "| epoch 10 | 1400/13484 batches | lr 3.15 | ms/batch 115.49 | loss 4.81 | ppl 122.50\n",
  1420. "| epoch 10 | 1600/13484 batches | lr 3.15 | ms/batch 115.44 | loss 4.85 | ppl 127.35\n",
  1421. "| epoch 10 | 1800/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.83 | ppl 124.97\n",
  1422. "| epoch 10 | 2000/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.80 | ppl 121.45\n",
  1423. "| epoch 10 | 2200/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.80 | ppl 121.97\n",
  1424. "| epoch 10 | 2400/13484 batches | lr 3.15 | ms/batch 115.37 | loss 4.80 | ppl 122.05\n",
  1425. "| epoch 10 | 2600/13484 batches | lr 3.15 | ms/batch 115.46 | loss 4.79 | ppl 120.16\n",
  1426. "| epoch 10 | 2800/13484 batches | lr 3.15 | ms/batch 115.42 | loss 4.83 | ppl 125.44\n"
  1427. ]
  1428. },
  1429. {
  1430. "name": "stdout",
  1431. "output_type": "stream",
  1432. "text": [
  1433. "| epoch 10 | 3000/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.81 | ppl 122.12\n",
  1434. "| epoch 10 | 3200/13484 batches | lr 3.15 | ms/batch 115.44 | loss 4.79 | ppl 120.18\n",
  1435. "| epoch 10 | 3400/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.84 | ppl 127.05\n",
  1436. "| epoch 10 | 3600/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.82 | ppl 123.70\n",
  1437. "| epoch 10 | 3800/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.80 | ppl 121.74\n",
  1438. "| epoch 10 | 4000/13484 batches | lr 3.15 | ms/batch 115.49 | loss 4.76 | ppl 116.53\n",
  1439. "| epoch 10 | 4200/13484 batches | lr 3.15 | ms/batch 115.42 | loss 4.78 | ppl 119.64\n",
  1440. "| epoch 10 | 4400/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.80 | ppl 121.17\n",
  1441. "| epoch 10 | 4600/13484 batches | lr 3.15 | ms/batch 115.42 | loss 4.84 | ppl 126.61\n",
  1442. "| epoch 10 | 4800/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.83 | ppl 124.71\n",
  1443. "| epoch 10 | 5000/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.81 | ppl 122.89\n",
  1444. "| epoch 10 | 5200/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.81 | ppl 123.00\n",
  1445. "| epoch 10 | 5400/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.79 | ppl 120.50\n",
  1446. "| epoch 10 | 5600/13484 batches | lr 3.15 | ms/batch 115.38 | loss 4.80 | ppl 121.56\n",
  1447. "| epoch 10 | 5800/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.80 | ppl 121.20\n",
  1448. "| epoch 10 | 6000/13484 batches | lr 3.15 | ms/batch 115.38 | loss 4.82 | ppl 123.72\n",
  1449. "| epoch 10 | 6200/13484 batches | lr 3.15 | ms/batch 115.35 | loss 4.85 | ppl 127.61\n",
  1450. "| epoch 10 | 6400/13484 batches | lr 3.15 | ms/batch 115.32 | loss 4.82 | ppl 124.04\n",
  1451. "| epoch 10 | 6600/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.85 | ppl 127.34\n",
  1452. "| epoch 10 | 6800/13484 batches | lr 3.15 | ms/batch 115.38 | loss 4.80 | ppl 121.21\n",
  1453. "| epoch 10 | 7000/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.84 | ppl 126.43\n",
  1454. "| epoch 10 | 7200/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.81 | ppl 122.41\n",
  1455. "| epoch 10 | 7400/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.81 | ppl 122.46\n",
  1456. "| epoch 10 | 7600/13484 batches | lr 3.15 | ms/batch 115.37 | loss 4.80 | ppl 122.05\n",
  1457. "| epoch 10 | 7800/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.80 | ppl 121.58\n",
  1458. "| epoch 10 | 8000/13484 batches | lr 3.15 | ms/batch 115.32 | loss 4.79 | ppl 120.04\n",
  1459. "| epoch 10 | 8200/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.78 | ppl 118.97\n",
  1460. "| epoch 10 | 8400/13484 batches | lr 3.15 | ms/batch 115.44 | loss 4.80 | ppl 121.55\n",
  1461. "| epoch 10 | 8600/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.82 | ppl 123.48\n",
  1462. "| epoch 10 | 8800/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.79 | ppl 119.98\n",
  1463. "| epoch 10 | 9000/13484 batches | lr 3.15 | ms/batch 115.37 | loss 4.80 | ppl 121.60\n",
  1464. "| epoch 10 | 9200/13484 batches | lr 3.15 | ms/batch 115.50 | loss 4.82 | ppl 124.26\n",
  1465. "| epoch 10 | 9400/13484 batches | lr 3.15 | ms/batch 115.44 | loss 4.83 | ppl 124.68\n",
  1466. "| epoch 10 | 9600/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.79 | ppl 120.72\n",
  1467. "| epoch 10 | 9800/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.77 | ppl 118.27\n",
  1468. "| epoch 10 | 10000/13484 batches | lr 3.15 | ms/batch 115.43 | loss 4.80 | ppl 121.83\n",
  1469. "| epoch 10 | 10200/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.76 | ppl 116.50\n",
  1470. "| epoch 10 | 10400/13484 batches | lr 3.15 | ms/batch 115.34 | loss 4.74 | ppl 114.00\n",
  1471. "| epoch 10 | 10600/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.79 | ppl 119.78\n",
  1472. "| epoch 10 | 10800/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.80 | ppl 121.63\n",
  1473. "| epoch 10 | 11000/13484 batches | lr 3.15 | ms/batch 115.32 | loss 4.80 | ppl 121.41\n",
  1474. "| epoch 10 | 11200/13484 batches | lr 3.15 | ms/batch 115.52 | loss 4.80 | ppl 121.11\n",
  1475. "| epoch 10 | 11400/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.75 | ppl 115.26\n",
  1476. "| epoch 10 | 11600/13484 batches | lr 3.15 | ms/batch 115.45 | loss 4.79 | ppl 120.63\n",
  1477. "| epoch 10 | 11800/13484 batches | lr 3.15 | ms/batch 115.36 | loss 4.75 | ppl 115.77\n",
  1478. "| epoch 10 | 12000/13484 batches | lr 3.15 | ms/batch 115.48 | loss 4.80 | ppl 121.70\n",
  1479. "| epoch 10 | 12200/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.74 | ppl 114.59\n",
  1480. "| epoch 10 | 12400/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.76 | ppl 116.79\n",
  1481. "| epoch 10 | 12600/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.77 | ppl 118.25\n",
  1482. "| epoch 10 | 12800/13484 batches | lr 3.15 | ms/batch 115.40 | loss 4.79 | ppl 120.47\n",
  1483. "| epoch 10 | 13000/13484 batches | lr 3.15 | ms/batch 115.41 | loss 4.78 | ppl 119.53\n",
  1484. "| epoch 10 | 13200/13484 batches | lr 3.15 | ms/batch 115.42 | loss 4.79 | ppl 120.28\n",
  1485. "| epoch 10 | 13400/13484 batches | lr 3.15 | ms/batch 115.39 | loss 4.81 | ppl 122.17\n",
  1486. "-----------------------------------------------------------------------------------------\n",
  1487. "| end of epoch 10 | time: 1628.47s | valid loss 4.98 | valid ppl 145.33\n",
  1488. "-----------------------------------------------------------------------------------------\n"
  1489. ]
  1490. }
  1491. ],
  1492. "source": [
  1493. "best_val_loss = float('inf')\n",
  1494. "epochs = 10\n",
  1495. "best_model = None\n",
  1496. "\n",
  1497. "for epoch in range(1, epochs + 1):\n",
  1498. " epoch_start_time = time.time()\n",
  1499. " train(model)\n",
  1500. " val_loss = evaluate(model, val_data)\n",
  1501. " val_ppl = math.exp(val_loss)\n",
  1502. " elapsed = time.time() - epoch_start_time\n",
  1503. " print('-' * 89)\n",
  1504. " print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '\n",
  1505. " f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')\n",
  1506. " print('-' * 89)\n",
  1507. "\n",
  1508. " if val_loss < best_val_loss:\n",
  1509. " best_val_loss = val_loss\n",
  1510. " best_model = copy.deepcopy(model)\n",
  1511. "\n",
  1512. " scheduler.step()"
  1513. ]
  1514. },
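{
"cell_type": "markdown",
"id": "3fa1b2c4",
"metadata": {},
"source": [
"The learning rates in the log above fall from 4.07 (epoch 5) to 3.15 (epoch 10), i.e. by a factor of roughly 0.95 per epoch. A minimal sketch reproducing that schedule with a `StepLR`, assuming `torch` is imported; the optimizer `opt` below is purely illustrative, the real one is created earlier in the notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d6e7f80",
"metadata": {},
"outputs": [],
"source": [
"# sketch: per-epoch multiplicative lr decay consistent with the log above\n",
"opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=5.0)\n",
"sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.95)\n",
"for ep in range(1, 11):\n",
"    # epochs 5-10 print 4.07, 3.87, 3.68, 3.49, 3.32, 3.15 as in the log\n",
"    print(f'epoch {ep}: lr {opt.param_groups[0][\"lr\"]:.2f}')\n",
"    sched.step()"
]
},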
  1515. {
  1516. "cell_type": "markdown",
  1517. "id": "f0d32419",
  1518. "metadata": {},
  1519. "source": [
  1520. "### print info about best model after training"
  1521. ]
  1522. },
  1523. {
  1524. "cell_type": "code",
  1525. "execution_count": 32,
  1526. "id": "12fdd0aa",
  1527. "metadata": {
  1528. "scrolled": true
  1529. },
  1530. "outputs": [
  1531. {
  1532. "name": "stdout",
  1533. "output_type": "stream",
  1534. "text": [
  1535. "=========================================================================================\n",
  1536. "| End of training | test loss 4.98 | test ppl 144.89\n",
  1537. "=========================================================================================\n"
  1538. ]
  1539. }
  1540. ],
  1541. "source": [
  1542. "test_loss = evaluate(best_model, test_data)\n",
  1543. "test_ppl = math.exp(test_loss)\n",
  1544. "print('=' * 89)\n",
  1545. "print(f'| End of training | test loss {test_loss:5.2f} | '\n",
  1546. " f'test ppl {test_ppl:8.2f}')\n",
  1547. "print('=' * 89)"
  1548. ]
  1549. },
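{
"cell_type": "markdown",
"id": "b7e41c09",
"metadata": {},
"source": [
"Perplexity is just the exponential of the average cross-entropy loss, so a test loss of about 4.98 gives exp(4.98) ≈ 145: on average the model is as uncertain as a uniform choice over roughly 145 vocabulary tokens."
]
},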
  1550. {
  1551. "cell_type": "markdown",
  1552. "id": "528c9f10",
  1553. "metadata": {},
  1554. "source": [
  1555. "### save trained model to file"
  1556. ]
  1557. },
  1558. {
  1559. "cell_type": "code",
  1560. "execution_count": 33,
  1561. "id": "848af399",
  1562. "metadata": {},
  1563. "outputs": [],
  1564. "source": [
  1565. "torch.save(best_model, \"pubmed-sentencecomplete.pt\")"
  1566. ]
  1567. },
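{
"cell_type": "markdown",
"id": "9f2e4a6c",
"metadata": {},
"source": [
"Note that `torch.save(best_model, ...)` pickles the entire model object, so loading it later requires the `TransformerModel` class definition to be importable under the same name. A sketch of the more portable state-dict convention (the second filename is hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d1c3b5a",
"metadata": {},
"outputs": [],
"source": [
"# sketch: save only the learned parameters; loading then requires building\n",
"# a TransformerModel with the same hyperparameters and load_state_dict()\n",
"torch.save(best_model.state_dict(), \"pubmed-sentencecomplete-state.pt\")"
]
},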
  1568. {
  1569. "cell_type": "markdown",
  1570. "id": "09df56cf",
  1571. "metadata": {},
  1572. "source": [
  1573. "## Now we can try to predict based on trained model"
  1574. ]
  1575. },
  1576. {
  1577. "cell_type": "markdown",
  1578. "id": "fe250072",
  1579. "metadata": {},
  1580. "source": [
  1581. "### obtain iterator for predict batch "
  1582. ]
  1583. },
  1584. {
  1585. "cell_type": "code",
  1586. "execution_count": 159,
  1587. "id": "afe585d6",
  1588. "metadata": {},
  1589. "outputs": [],
  1590. "source": [
  1591. "def predict_abstract_iter(batch):\n",
  1592. " for batch in batch:\n",
  1593. " yield tokenizer(batch)"
  1594. ]
  1595. },
  1596. {
  1597. "cell_type": "markdown",
  1598. "id": "b043de0a",
  1599. "metadata": {},
  1600. "source": [
  1601. "### load data into tensor for model to process"
  1602. ]
  1603. },
  1604. {
  1605. "cell_type": "code",
  1606. "execution_count": 154,
  1607. "id": "8bfaa8bd",
  1608. "metadata": {},
  1609. "outputs": [],
  1610. "source": [
  1611. "def toDataTensor(batch):\n",
  1612. " predict_generator = predict_abstract_iter(batch)\n",
  1613. " return [torch.tensor(vocab.lookup_indices(item)) for item in predict_generator]"
  1614. ]
  1615. },
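{
"cell_type": "markdown",
"id": "c4d2e8f0",
"metadata": {},
"source": [
"A quick sanity check of the two helpers (assuming the `tokenizer` and `vocab` built earlier in the notebook are loaded); the concrete indices depend on the trained vocabulary:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9a1b3c5",
"metadata": {},
"outputs": [],
"source": [
"# sketch of the expected usage: one tensor of token indices per sentence\n",
"example = toDataTensor([\"There is\"])\n",
"print(example[0].shape, example[0].dtype)"
]
},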
  1616. {
  1617. "cell_type": "markdown",
  1618. "id": "a800ffea",
  1619. "metadata": {},
  1620. "source": [
  1621. "### check device once again (prob not needed)"
  1622. ]
  1623. },
  1624. {
  1625. "cell_type": "code",
  1626. "execution_count": 7,
  1627. "id": "6e2c35ba",
  1628. "metadata": {},
  1629. "outputs": [
  1630. {
  1631. "ename": "NameError",
  1632. "evalue": "name 'torch' is not defined",
  1633. "output_type": "error",
  1634. "traceback": [
  1635. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  1636. "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
  1637. "Cell \u001b[0;32mIn [7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m device \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 2\u001b[0m device\n",
  1638. "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
  1639. ]
  1640. }
  1641. ],
  1642. "source": [
  1643. "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
  1644. "device"
  1645. ]
  1646. },
  1647. {
  1648. "cell_type": "markdown",
  1649. "id": "bef90722",
  1650. "metadata": {},
  1651. "source": [
  1652. "### optionally load model from file if it was trained already"
  1653. ]
  1654. },
  1655. {
  1656. "cell_type": "code",
  1657. "execution_count": 50,
  1658. "id": "223eed8a",
  1659. "metadata": {},
  1660. "outputs": [
  1661. {
  1662. "data": {
  1663. "text/plain": [
  1664. "TransformerModel(\n",
  1665. " (pos_encoder): PositionalEncoding(\n",
  1666. " (dropout): Dropout(p=0.2, inplace=False)\n",
  1667. " )\n",
  1668. " (transformer_encoder): TransformerEncoder(\n",
  1669. " (layers): ModuleList(\n",
  1670. " (0): TransformerEncoderLayer(\n",
  1671. " (self_attn): MultiheadAttention(\n",
  1672. " (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)\n",
  1673. " )\n",
  1674. " (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
  1675. " (dropout): Dropout(p=0.2, inplace=False)\n",
  1676. " (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
  1677. " (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
  1678. " (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
  1679. " (dropout1): Dropout(p=0.2, inplace=False)\n",
  1680. " (dropout2): Dropout(p=0.2, inplace=False)\n",
  1681. " )\n",
  1682. " (1): TransformerEncoderLayer(\n",
  1683. " (self_attn): MultiheadAttention(\n",
  1684. " (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)\n",
  1685. " )\n",
  1686. " (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
  1687. " (dropout): Dropout(p=0.2, inplace=False)\n",
  1688. " (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
  1689. " (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
  1690. " (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
  1691. " (dropout1): Dropout(p=0.2, inplace=False)\n",
  1692. " (dropout2): Dropout(p=0.2, inplace=False)\n",
  1693. " )\n",
  1694. " )\n",
  1695. " )\n",
  1696. " (encoder): Embedding(163987, 200)\n",
  1697. " (decoder): Linear(in_features=200, out_features=163987, bias=True)\n",
  1698. ")"
  1699. ]
  1700. },
  1701. "execution_count": 50,
  1702. "metadata": {},
  1703. "output_type": "execute_result"
  1704. }
  1705. ],
  1706. "source": [
  1707. "best_model = torch.load(\"pubmed-sentencecomplete.pt\")\n",
  1708. "best_model.eval()"
  1709. ]
  1710. },
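{
"cell_type": "markdown",
"id": "a6b8c0d2",
"metadata": {},
"source": [
"If the model was trained on a GPU but is being loaded on a CPU-only machine, `map_location` remaps the stored tensors onto the available device; a sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c3d5e9",
"metadata": {},
"outputs": [],
"source": [
"# sketch: remap GPU-saved weights onto whatever device is available\n",
"best_model = torch.load(\"pubmed-sentencecomplete.pt\", map_location=device)\n",
"best_model.eval()"
]
},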
  1711. {
  1712. "cell_type": "markdown",
  1713. "id": "dd71bdfc",
  1714. "metadata": {},
  1715. "source": [
  1716. "### define predict function"
  1717. ]
  1718. },
  1719. {
  1720. "cell_type": "code",
  1721. "execution_count": 160,
  1722. "id": "64223e87",
  1723. "metadata": {},
  1724. "outputs": [],
  1725. "source": [
  1726. "def predict(input_line, mask, n_predictions=3):\n",
  1727. " with torch.no_grad():\n",
  1728. " output = best_model(input_line.to(device), mask) \n",
  1729. " predictions = []\n",
  1730. " for i in range(n_predictions):\n",
  1731. " next_item = output.topk(i+1)[1].view(-1)[-1].item()\n",
  1732. " predict_token_index = next_item\n",
  1733. " predictions.append(vocab.lookup_token(predict_token_index))\n",
  1734. " \n",
  1735. " return predictions"
  1736. ]
  1737. },
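{
"cell_type": "markdown",
"id": "7c5e9f1b",
"metadata": {},
"source": [
"The loop above calls `topk` once per rank; an equivalent sketch that fetches all candidates for the last position with a single `topk` call (same assumptions: `best_model`, `vocab` and `device` are defined):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b4d8e0a",
"metadata": {},
"outputs": [],
"source": [
"def predict_topk_once(input_line, mask, n_predictions=3):\n",
"    # one forward pass over the whole input sequence\n",
"    with torch.no_grad():\n",
"        output = best_model(input_line.to(device), mask)\n",
"    # logits of the final position, independent of an optional batch dim\n",
"    last_logits = output.view(-1, output.size(-1))[-1]\n",
"    # indices of the n_predictions most likely next tokens, best first\n",
"    top_indices = last_logits.topk(n_predictions)[1]\n",
"    return [vocab.lookup_token(idx.item()) for idx in top_indices]"
]
},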
  1738. {
  1739. "cell_type": "markdown",
  1740. "id": "a9b7311b",
  1741. "metadata": {},
  1742. "source": [
  1743. "### define input batch "
  1744. ]
  1745. },
  1746. {
  1747. "cell_type": "code",
  1748. "execution_count": 2,
  1749. "id": "913628b4",
  1750. "metadata": {},
  1751. "outputs": [],
  1752. "source": [
  1753. "sample_batch = [\n",
  1754. " \"There is\"\n",
  1755. "]\n",
  1756. "input_batch = sample_batch"
  1757. ]
  1758. },
  1759. {
  1760. "cell_type": "markdown",
  1761. "id": "45930a71",
  1762. "metadata": {},
  1763. "source": [
  1764. "### define initial source mask for model"
  1765. ]
  1766. },
  1767. {
  1768. "cell_type": "code",
  1769. "execution_count": 3,
  1770. "id": "c4bba2a1",
  1771. "metadata": {},
  1772. "outputs": [
  1773. {
  1774. "ename": "NameError",
  1775. "evalue": "name 'generate_square_subsequent_mask' is not defined",
  1776. "output_type": "error",
  1777. "traceback": [
  1778. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  1779. "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
  1780. "Cell \u001b[0;32mIn [3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m bptt \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 2\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(bptt)\u001b[38;5;241m.\u001b[39mto(device)\n",
  1781. "\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined"
  1782. ]
  1783. }
  1784. ],
  1785. "source": [
  1786. "bptt = 2\n",
  1787. "src_mask = generate_square_subsequent_mask(bptt).to(device)"
  1788. ]
  1789. },
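{
"cell_type": "markdown",
"id": "e2f7a9b1",
"metadata": {},
"source": [
"For reference, `generate_square_subsequent_mask` is the causal-mask helper defined earlier in the notebook; the NameError above just means that cell was not re-run in this session. A minimal sketch, assuming the standard PyTorch-tutorial definition: -inf above the diagonal, zeros elsewhere, so position i can only attend to positions <= i."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8c3d5e7",
"metadata": {},
"outputs": [],
"source": [
"# sketch, assuming the tutorial-style definition used earlier:\n",
"def generate_square_subsequent_mask(sz):\n",
"    # upper-triangular -inf mask that blocks attention to future positions\n",
"    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)\n",
"\n",
"print(generate_square_subsequent_mask(3))"
]
},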
  1790. {
  1791. "cell_type": "markdown",
  1792. "id": "5b33b9f3",
  1793. "metadata": {},
  1794. "source": [
  1795. "### Execute prediction and display predicted values and choose continuation"
  1796. ]
  1797. },
  1798. {
  1799. "cell_type": "code",
  1800. "execution_count": 4,
  1801. "id": "b2895698",
  1802. "metadata": {},
  1803. "outputs": [],
  1804. "source": [
  1805. "def predict_loop(num_of_pred):\n",
  1806. " iteration = 0\n",
  1807. " is_terminated = False\n",
  1808. " input_batch = sample_batch\n",
  1809. " while(not is_terminated):\n",
  1810. " mask_size = bptt+(iteration) \n",
  1811. " src_mask = generate_square_subsequent_mask(mask_size).to(device)\n",
  1812. " data = toDataTensor(input_batch)\n",
  1813. " \n",
  1814. " for i, d in enumerate(data):\n",
  1815. " predictions = predict(d, src_mask, num_of_pred)\n",
  1816. " \n",
  1817. " print(\"\\n Possible continuations:\")\n",
  1818. " for j in range(len(predictions)):\n",
  1819. " print(j + 1, \": \", predictions[j])\n",
  1820. " s_index = input(input_batch[i])\n",
  1821. " if(\"e\" in s_index):\n",
  1822. " is_terminated = True\n",
  1823. " print(\"prediction stopped.\")\n",
  1824. " break\n",
  1825. "\n",
  1826. " print(\"Text is now:\")\n",
  1827. " input_batch[i] += (\" \" + predictions[int(s_index) - 1])\n",
  1828. " print(input_batch[i])\n",
  1829. "\n",
  1830. " iteration = iteration + 1"
  1831. ]
  1832. },
  1833. {
  1834. "cell_type": "code",
  1835. "execution_count": 5,
  1836. "id": "13ed9298",
  1837. "metadata": {},
  1838. "outputs": [
  1839. {
  1840. "ename": "NameError",
  1841. "evalue": "name 'generate_square_subsequent_mask' is not defined",
  1842. "output_type": "error",
  1843. "traceback": [
  1844. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  1845. "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
  1846. "Cell \u001b[0;32mIn [5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mpredict_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n",
  1847. "Cell \u001b[0;32mIn [4], line 7\u001b[0m, in \u001b[0;36mpredict_loop\u001b[0;34m(num_of_pred)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m(\u001b[38;5;129;01mnot\u001b[39;00m is_terminated):\n\u001b[1;32m 6\u001b[0m mask_size \u001b[38;5;241m=\u001b[39m bptt\u001b[38;5;241m+\u001b[39m(iteration) \n\u001b[0;32m----> 7\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(mask_size)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m toDataTensor(input_batch)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, d \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(data):\n",
  1848. "\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined"
  1849. ]
  1850. }
  1851. ],
  1852. "source": [
  1853. "predict_loop(3)"
  1854. ]
  1855. }
  1856. ],
  1857. "metadata": {
  1858. "kernelspec": {
  1859. "display_name": "Python 3 (ipykernel)",
  1860. "language": "python",
  1861. "name": "python3"
  1862. },
  1863. "language_info": {
  1864. "codemirror_mode": {
  1865. "name": "ipython",
  1866. "version": 3
  1867. },
  1868. "file_extension": ".py",
  1869. "mimetype": "text/x-python",
  1870. "name": "python",
  1871. "nbconvert_exporter": "python",
  1872. "pygments_lexer": "ipython3",
  1873. "version": "3.10.8"
  1874. }
  1875. },
  1876. "nbformat": 4,
  1877. "nbformat_minor": 5
  1878. }