{
"cells": [
{
"cell_type": "markdown",
"id": "622dfcd6",
"metadata": {},
"source": [
"# Group 09 - Automatic Sentence Completion for PubMed"
]
},
{
"cell_type": "markdown",
"id": "5e4cec3c",
"metadata": {},
"source": [
"### Authors:\n",
"- Eric Münzberg\n",
"- Shahein Enjjar\t\n",
"- Leonard Starke"
]
},
{
"cell_type": "markdown",
"id": "ee9c1d92",
"metadata": {},
"source": [
"### Firstly try to import the data modules"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "e444b44c",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" from Bio import Entrez, Medline \n",
"except:\n",
" !pip install Bio\n",
" from Bio import Entrez, Medline \n"
]
},
{
"cell_type": "markdown",
"id": "7bf15c30",
"metadata": {},
"source": [
"### define function for loading the papers from PubMed database"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "adfb256a",
"metadata": {},
"outputs": [],
"source": [
"def getPapers(myQuery, maxPapers, myEmail =\"leonard.starke@mailbox.tu-dresden.de\"):\n",
" # Get articles from PubMed\n",
" Entrez.email =myEmail\n",
" record =Entrez.read(Entrez.esearch(db=\"pubmed\", term=myQuery, retmax=maxPapers))\n",
" idlist = record[\"IdList\"]\n",
" print(\"\\nThere are %d records for %s.\"%(len(idlist), myQuery.strip()))\n",
" records = Medline.parse(Entrez.efetch(db=\"pubmed\", id=idlist, rettype=\"medline\", retmode=\"text\"))\n",
" return list(records)"
]
},
{
"cell_type": "markdown",
"id": "46bc6298",
"metadata": {},
"source": [
"### Verify that its working"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "00481ec9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"There are 1600 records for Cancer[tiab].\n"
]
}
],
"source": [
"myQuery =\"Cancer\"+\"[tiab]\" #query in title and abstract\n",
"maxPapers = 1600 # thinkabout outsourcing params to seperate section\n",
"records = getPapers(myQuery, maxPapers)\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "56cf72de",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1600"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(records)"
]
},
{
"cell_type": "markdown",
"id": "b67747c6",
"metadata": {},
"source": [
"### Now extract abstracts from records"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "dcf5c217",
"metadata": {},
"outputs": [],
"source": [
"r_abstracts = []\n",
"for r in records:\n",
" if not (r.get('AB') is None):\n",
" r_abstracts.append(r['AB'])"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "eb1fb38b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1532"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(r_abstracts)"
]
},
{
"cell_type": "markdown",
"id": "e309f6fe",
"metadata": {},
"source": [
"### Now import torch modules needed to load the data"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "c3199444",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import torch\n",
" from torch.utils.data import Dataset \n",
" from torchtext.data import get_tokenizer\n",
"except:\n",
" !pip --default-timeout=1000 install torch\n",
" !pip --default-timeout=1000 install torchtext\n",
" import torch\n",
" from torch.utils.data import Dataset \n",
" from torchtext.data import get_tokenizer"
]
},
{
"cell_type": "markdown",
"id": "5b4007e8",
"metadata": {},
"source": [
"### Import numpy"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "daca9db6",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import numpy as np\n",
"except:\n",
" !pip install numpy\n",
" import numpy as np\n"
]
},
{
"cell_type": "markdown",
"id": "4df1e449",
"metadata": {},
"source": [
"### define token iterators"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "8a128d3c",
"metadata": {},
"outputs": [],
"source": [
"def train_abstract_iter():\n",
" for abstract in r_abstracts[:1000]:\n",
" yield abstract"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "97e89986",
"metadata": {},
"outputs": [],
"source": [
"def val_abstract_iter():\n",
" for abstract in r_abstracts[1001:1300]:\n",
" yield abstract"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "0d6e89c4",
"metadata": {},
"outputs": [],
"source": [
"def test_abstract_iter():\n",
" for abstract in r_abstracts[1301:1542]:\n",
" yield abstract"
]
},
{
"cell_type": "markdown",
"id": "e5e9c5a2",
"metadata": {},
"source": [
"### define Tokenize function"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "0bdbc40a",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = get_tokenizer(\"basic_english\")\n",
"def tokenize_abstract_iter():\n",
" for abstract in r_abstracts:\n",
" yield tokenizer(abstract)"
]
},
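{
"cell_type": "markdown",
"id": "9a1f3c2e",
"metadata": {},
"source": [
"Optional sanity check (illustrative only, with an arbitrary example sentence): the `basic_english` tokenizer lowercases the text and splits on punctuation, producing the token stream the vocabulary is built from."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b2d4e6f",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the basic_english tokenizer on an arbitrary example sentence\n",
"tokenizer(\"The brain is a complex organ.\")"
]
},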
{
"cell_type": "markdown",
"id": "37da40bb",
"metadata": {},
"source": [
"### Map every world to a id to store inside torch tensor"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "a438ab1f",
"metadata": {},
"outputs": [],
"source": [
"from torchtext.vocab import build_vocab_from_iterator\n",
"token_generator = tokenize_abstract_iter()\n",
"vocab = build_vocab_from_iterator(token_generator, specials=['<unk>'])\n",
"vocab.set_default_index(vocab['<unk>'])\n"
]
},
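{
"cell_type": "markdown",
"id": "5c8e7a90",
"metadata": {},
"source": [
"Optional sanity check: the vocabulary size and a few token-to-index lookups (the example tokens below are arbitrary). Unseen tokens fall back to the `<unk>` default index set above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d9f8b01",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the vocabulary (example tokens are arbitrary)\n",
"print(len(vocab))                          # number of tokens including the <unk> special\n",
"print(vocab(['cancer', 'the', '<unk>']))   # map a few tokens to their integer ids\n",
"print(vocab['zzzqqq'] == vocab['<unk>'])   # an unseen token maps to the default index"
]
},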
{
"cell_type": "markdown",
"id": "221bdc48",
"metadata": {},
"source": [
"### now convert to tensor\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "0e5bc361",
"metadata": {},
"outputs": [],
"source": [
"def data_process(tokens_iter):\n",
" \"\"\"Converts raw text into a flat Tensor.\"\"\"\n",
" data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in tokens_iter]\n",
" return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "dfd7400d",
"metadata": {},
"outputs": [],
"source": [
"train_generator = train_abstract_iter()\n",
"val_generator = val_abstract_iter()\n",
"test_generator = test_abstract_iter()\n",
"train_data = data_process(train_generator)\n",
"val_data = data_process(val_generator)\n",
"test_data = data_process(test_generator)"
]
},
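{
"cell_type": "markdown",
"id": "1e2f3a4b",
"metadata": {},
"source": [
"Each of the three tensors is now a flat 1-D sequence of token ids. An optional length check follows; the exact numbers depend on the records fetched above, so they will vary between runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f3a4b5c",
"metadata": {},
"outputs": [],
"source": [
"# Optional check: lengths of the flat token-id tensors (values vary with the fetched records)\n",
"train_data.size(0), val_data.size(0), test_data.size(0)"
]
},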
{
"cell_type": "markdown",
"id": "c49a2734",
"metadata": {},
"source": [
"### check gpu"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "c155ee31",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "79b2d248",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device"
]
},
{
"cell_type": "markdown",
"id": "2150ba71",
"metadata": {},
"source": [
"### define model"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "a33d722f",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"from typing import Tuple\n",
"\n",
"import torch\n",
"from torch import nn, Tensor\n",
"import torch.nn.functional as F\n",
"from torch.nn import TransformerEncoder, TransformerEncoderLayer\n",
"from torch.utils.data import dataset\n",
"\n",
"class TransformerModel(nn.Module):\n",
"\n",
" def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,\n",
" nlayers: int, dropout: float = 0.5):\n",
" super().__init__()\n",
" self.model_type = 'Transformer'\n",
" self.pos_encoder = PositionalEncoding(d_model, dropout)\n",
" encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)\n",
" self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n",
" self.encoder = nn.Embedding(ntoken, d_model)\n",
" self.d_model = d_model\n",
" self.decoder = nn.Linear(d_model, ntoken)\n",
"\n",
" self.init_weights()\n",
"\n",
" def init_weights(self) -> None:\n",
" initrange = 0.1\n",
" self.encoder.weight.data.uniform_(-initrange, initrange)\n",
" self.decoder.bias.data.zero_()\n",
" self.decoder.weight.data.uniform_(-initrange, initrange)\n",
"\n",
" def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:\n",
" \"\"\"\n",
" Args:\n",
" src: Tensor, shape [seq_len, batch_size]\n",
" src_mask: Tensor, shape [seq_len, seq_len]\n",
"\n",
" Returns:\n",
" output Tensor of shape [seq_len, batch_size, ntoken]\n",
" \"\"\"\n",
" src = self.encoder(src) * math.sqrt(self.d_model)\n",
" src = self.pos_encoder(src)\n",
" output = self.transformer_encoder(src, src_mask)\n",
" output = self.decoder(output)\n",
" return output\n",
"\n",
"\n",
"def generate_square_subsequent_mask(sz: int) -> Tensor:\n",
" \"\"\"Generates an upper-triangular matrix of -inf, with zeros on diag.\"\"\"\n",
" return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)"
]
},
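{
"cell_type": "markdown",
"id": "3a4b5c6d",
"metadata": {},
"source": [
"The square subsequent mask is what makes the encoder causal: position *i* may only attend to positions up to *i*. A small illustrative example (size 5 chosen arbitrarily) is shown below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b5c6d7e",
"metadata": {},
"outputs": [],
"source": [
"# Illustration: causal mask for a sequence of length 5 (upper triangle is -inf, diagonal is 0)\n",
"generate_square_subsequent_mask(5)"
]
},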
{
"cell_type": "markdown",
"id": "3b78cc08",
"metadata": {},
"source": [
"### define pos encoder"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "c2f6d33b",
"metadata": {},
"outputs": [],
"source": [
"class PositionalEncoding(nn.Module):\n",
"\n",
" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):\n",
" super().__init__()\n",
" self.dropout = nn.Dropout(p=dropout)\n",
"\n",
" position = torch.arange(max_len).unsqueeze(1)\n",
" div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
" pe = torch.zeros(max_len, 1, d_model)\n",
" pe[:, 0, 0::2] = torch.sin(position * div_term)\n",
" pe[:, 0, 1::2] = torch.cos(position * div_term)\n",
" self.register_buffer('pe', pe)\n",
"\n",
" def forward(self, x: Tensor) -> Tensor:\n",
" \"\"\"\n",
" Args:\n",
" x: Tensor, shape [seq_len, batch_size, embedding_dim]\n",
" \"\"\"\n",
" x = x + self.pe[:x.size(0)]\n",
" return self.dropout(x)\n"
]
},
{
"cell_type": "markdown",
"id": "0adefcce",
"metadata": {},
"source": [
"### define function to create batches of data and create batches"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "9e184841",
"metadata": {},
"outputs": [],
"source": [
"def batchify(data: Tensor, bsz: int) -> Tensor:\n",
" \"\"\"Divides the data into bsz separate sequences, removing extra elements\n",
" that wouldn't cleanly fit.\n",
"\n",
" Args:\n",
" data: Tensor, shape [N]\n",
" bsz: int, batch size\n",
"\n",
" Returns:\n",
" Tensor of shape [N // bsz, bsz]\n",
" \"\"\"\n",
" seq_len = data.size(0) // bsz\n",
" data = data[:seq_len * bsz]\n",
" data = data.view(bsz, seq_len).t().contiguous()\n",
" return data.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "a4def1ac",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 20\n",
"eval_batch_size = 10\n",
"train_data = batchify(train_data, batch_size) # shape [seq_len, batch_size]\n",
"val_data = batchify(val_data, eval_batch_size)\n",
"test_data = batchify(test_data, eval_batch_size)"
]
},
{
"cell_type": "markdown",
"id": "4f407ad0",
"metadata": {},
"source": [
"### define function to get batch"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "4ab5b8fd",
"metadata": {},
"outputs": [],
"source": [
"bptt = 35\n",
"def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:\n",
" \"\"\"\n",
" Args:\n",
" source: Tensor, shape [full_seq_len, batch_size]\n",
" i: int\n",
"\n",
" Returns:\n",
" tuple (data, target), where data has shape [seq_len, batch_size] and\n",
" target has shape [seq_len * batch_size]\n",
" \"\"\"\n",
" seq_len = min(bptt, len(source) - 1 - i)\n",
" data = source[i:i+seq_len]\n",
" target = source[i+1:i+1+seq_len].reshape(-1)\n",
" return data, target"
]
},
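{
"cell_type": "markdown",
"id": "5c6d7e8f",
"metadata": {},
"source": [
"An illustrative shape check: with `bptt = 35` and `batch_size = 20`, `get_batch` should return an input chunk of shape `[seq_len, batch_size]` and a flattened target that is shifted by one token."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d7e8f90",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative shape check: inputs are [bptt, batch_size], targets are flattened and shifted by one\n",
"example_data, example_targets = get_batch(train_data, 0)\n",
"example_data.shape, example_targets.shape"
]
},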
{
"cell_type": "markdown",
"id": "7ee28c38",
"metadata": {},
"source": [
"### define parameters and init model"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "c53764da",
"metadata": {},
"outputs": [],
"source": [
"ntokens = len(vocab) # size of vocabulary\n",
"emsize = 200 # embedding dimension\n",
"d_hid = 200 # dimension of the feedforward network model in nn.TransformerEncoder\n",
"nlayers = 2 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder\n",
"nhead = 2 # number of heads in nn.MultiheadAttention\n",
"dropout = 0.2 # dropout probability\n",
"model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)"
]
},
{
"cell_type": "markdown",
"id": "51f2400a",
"metadata": {},
"source": [
"### init optimizer, loss, scheduler etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9a04e07",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import time\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"lr = 5.0 # learning rate\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)"
]
},
{
"cell_type": "markdown",
"id": "07317af8",
"metadata": {},
"source": [
"### define train function"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "50ab3fb6",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def train(model: nn.Module) -> None:\n",
" model.train() # turn on train mode\n",
" total_loss = 0.\n",
" log_interval = 200\n",
" start_time = time.time()\n",
" src_mask = generate_square_subsequent_mask(bptt).to(device)\n",
"\n",
" num_batches = len(train_data) // bptt\n",
" for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n",
" data, targets = get_batch(train_data, i)\n",
" seq_len = data.size(0)\n",
" if seq_len != bptt: # only on last batch\n",
" src_mask = src_mask[:seq_len, :seq_len]\n",
" output = model(data, src_mask)\n",
" loss = criterion(output.view(-1, ntokens), targets)\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item()\n",
" if batch % log_interval == 0 and batch > 0:\n",
" lr = scheduler.get_last_lr()[0]\n",
" ms_per_batch = (time.time() - start_time) * 1000 / log_interval\n",
" cur_loss = total_loss / log_interval\n",
" ppl = math.exp(cur_loss)\n",
" print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '\n",
" f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n",
" f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n",
" total_loss = 0\n",
" start_time = time.time()"
]
},
{
"cell_type": "markdown",
"id": "23709949",
"metadata": {},
"source": [
"### define evaluate function"
]
},
{
"cell_type": "code",
"execution_count": 289,
"id": "689bd4ea",
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n",
" model.eval() # turn on evaluation mode\n",
" total_loss = 0.\n",
" src_mask = generate_square_subsequent_mask(bptt).to(device)\n",
" with torch.no_grad():\n",
" for i in range(0, eval_data.size(0) - 1, bptt):\n",
" data, targets = get_batch(eval_data, i)\n",
" seq_len = data.size(0)\n",
" if seq_len != bptt:\n",
" src_mask = src_mask[:seq_len, :seq_len]\n",
" output = model(data, src_mask)\n",
" output_flat = output.view(-1, ntokens)\n",
" total_loss += seq_len * criterion(output_flat, targets).item()\n",
" return total_loss / (len(eval_data) - 1)"
]
},
{
"cell_type": "markdown",
"id": "d7c6a1e0",
"metadata": {},
"source": [
"### now we can start training the model while saving best one"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "09c4d4ce",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 1 | 200/ 383 batches | lr 5.00 | ms/batch 63.06 | loss 8.09 | ppl 3258.10\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 1 | time: 17.05s | valid loss 6.34 | valid ppl 566.15\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 2 | 200/ 383 batches | lr 4.75 | ms/batch 19.83 | loss 6.14 | ppl 463.13\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 2 | time: 8.38s | valid loss 6.01 | valid ppl 406.56\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 3 | 200/ 383 batches | lr 4.51 | ms/batch 19.83 | loss 5.61 | ppl 273.67\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 3 | time: 8.38s | valid loss 5.95 | valid ppl 383.10\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 4 | 200/ 383 batches | lr 4.29 | ms/batch 19.89 | loss 5.25 | ppl 190.90\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 4 | time: 8.40s | valid loss 5.96 | valid ppl 386.38\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 5 | 200/ 383 batches | lr 4.07 | ms/batch 19.88 | loss 4.96 | ppl 142.55\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 5 | time: 8.40s | valid loss 5.99 | valid ppl 398.76\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 6 | 200/ 383 batches | lr 3.87 | ms/batch 19.89 | loss 4.71 | ppl 111.09\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 6 | time: 8.40s | valid loss 6.04 | valid ppl 421.64\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 7 | 200/ 383 batches | lr 3.68 | ms/batch 19.89 | loss 4.49 | ppl 89.44\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 7 | time: 8.40s | valid loss 6.11 | valid ppl 452.51\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 8 | 200/ 383 batches | lr 3.49 | ms/batch 19.92 | loss 4.30 | ppl 73.72\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 8 | time: 8.42s | valid loss 6.17 | valid ppl 479.04\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 9 | 200/ 383 batches | lr 3.32 | ms/batch 19.93 | loss 4.13 | ppl 62.43\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 9 | time: 8.42s | valid loss 6.26 | valid ppl 522.27\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 10 | 200/ 383 batches | lr 3.15 | ms/batch 19.95 | loss 3.99 | ppl 53.96\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 10 | time: 8.43s | valid loss 6.31 | valid ppl 548.35\n",
"-----------------------------------------------------------------------------------------\n"
]
}
],
"source": [
"best_val_loss = float('inf')\n",
"epochs = 10\n",
"best_model = None\n",
"\n",
"for epoch in range(1, epochs + 1):\n",
" epoch_start_time = time.time()\n",
" train(model)\n",
" val_loss = evaluate(model, val_data)\n",
" val_ppl = math.exp(val_loss)\n",
" elapsed = time.time() - epoch_start_time\n",
" print('-' * 89)\n",
" print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '\n",
" f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')\n",
" print('-' * 89)\n",
"\n",
" if val_loss < best_val_loss:\n",
" best_val_loss = val_loss\n",
" best_model = copy.deepcopy(model)\n",
"\n",
" scheduler.step()"
]
},
{
"cell_type": "markdown",
"id": "565b5aa4",
"metadata": {},
"source": [
"### print info about best model after training"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "12fdd0aa",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=========================================================================================\n",
"| End of training | test loss 5.80 | test ppl 329.59\n",
"=========================================================================================\n"
]
}
],
"source": [
"test_loss = evaluate(best_model, test_data)\n",
"test_ppl = math.exp(test_loss)\n",
"print('=' * 89)\n",
"print(f'| End of training | test loss {test_loss:5.2f} | '\n",
" f'test ppl {test_ppl:8.2f}')\n",
"print('=' * 89)"
]
},
{
"cell_type": "markdown",
"id": "12031065",
"metadata": {},
"source": [
"## Now we can try to predict based on trained model"
]
},
{
"cell_type": "markdown",
"id": "e685d3e1",
"metadata": {},
"source": [
"### define input batch "
]
},
{
"cell_type": "code",
"execution_count": 300,
"id": "cfb30fe0",
"metadata": {},
"outputs": [],
"source": [
"input_batch = [\n",
" \"The brain is\",\n",
" \"The lung is\"\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "054ada71",
"metadata": {},
"source": [
"### define source mask for model"
]
},
{
"cell_type": "code",
"execution_count": 301,
"id": "305853e8",
"metadata": {},
"outputs": [],
"source": [
"bptt = 3\n",
"src_mask = generate_square_subsequent_mask(bptt).to(device)"
]
},
{
"cell_type": "markdown",
"id": "4635a73e",
"metadata": {},
"source": [
"### define iterator for predict batch and init to generator"
]
},
{
"cell_type": "code",
"execution_count": 302,
"id": "afe585d6",
"metadata": {},
"outputs": [],
"source": [
"def predict_abstract_iter():\n",
" for batch in input_batch:\n",
" yield tokenizer(batch)\n",
"predict_generator = predict_abstract_iter()"
]
},
{
"cell_type": "markdown",
"id": "1c171c8c",
"metadata": {},
"source": [
"### load data into tensor for model to process"
]
},
{
"cell_type": "code",
"execution_count": 303,
"id": "0788b045",
"metadata": {},
"outputs": [],
"source": [
"data = [torch.tensor(vocab.lookup_indices(item)) for item in predict_generator]"
]
},
{
"cell_type": "code",
"execution_count": 308,
"id": "8bfaa8bd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]"
]
},
"execution_count": 308,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data"
]
},
{
"cell_type": "markdown",
"id": "99132b3d",
"metadata": {},
"source": [
"### check device once again (prob not needed)"
]
},
{
"cell_type": "code",
"execution_count": 309,
"id": "b8c50c8c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 309,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"device"
]
},
{
"cell_type": "markdown",
"id": "05766f6b",
"metadata": {},
"source": [
"### define predict function"
]
},
{
"cell_type": "code",
"execution_count": 317,
"id": "0475bcc9",
"metadata": {},
"outputs": [],
"source": [
"def predict(input_line, n_predictions=3):\n",
" print('\\n> %s' % input_line)\n",
" with torch.no_grad():\n",
" output = best_model(input_line.to(device), src_mask)\n",
"\n",
" # Get top N categories\n",
" topv, topi = output.topk(n_predictions, 1, True)\n",
" #x, y = output.topk(n_predictions, 1, True)\n",
" #print(x.shape)\n",
" #print(topv.shape)\n",
" # print(topi.shape)\n",
" predictions = []\n",
" for i in range(n_predictions):\n",
" value = topv[0][i]\n",
" v1, v2 = value.topk(1)\n",
" predict_token_index = v2.cpu().detach().numpy()\n",
" print(vocab.lookup_token(predict_token_index))\n",
" #print(category_index)\n",
" #print('(%.2f) %s' % (value, all_categories[category_index]))\n",
" predictions.append(vocab.lookup_token(predict_token_index))\n",
" return predictions"
]
},
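{
"cell_type": "markdown",
"id": "7e8f9012",
"metadata": {},
"source": [
"As an aside, a minimal greedy next-token sketch (an illustrative alternative, not the notebook's original prediction path): it adds an explicit batch dimension, builds a mask matching the prompt length, and reads the output distribution at the last position only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f901234",
"metadata": {},
"outputs": [],
"source": [
"# Minimal greedy next-token sketch (assumes best_model, vocab, tokenizer and device from above)\n",
"def greedy_next_token(prompt):\n",
"    tokens = torch.tensor(vocab(tokenizer(prompt)), dtype=torch.long).unsqueeze(1).to(device)  # [seq_len, 1]\n",
"    mask = generate_square_subsequent_mask(tokens.size(0)).to(device)\n",
"    with torch.no_grad():\n",
"        logits = best_model(tokens, mask)  # [seq_len, 1, ntokens]\n",
"    # take the highest-scoring vocabulary entry at the last position\n",
"    return vocab.lookup_token(logits[-1, 0].argmax().item())\n",
"\n",
"greedy_next_token(\"The brain is\")"
]
},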
{
"cell_type": "markdown",
"id": "8ad2f64b",
"metadata": {},
"source": [
"### Execute prediction and display predicted values"
]
},
{
"cell_type": "code",
"execution_count": 318,
"id": "55b73ea1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]\n",
"\n",
"> tensor([ 3, 555, 16])\n",
"tumors\n",
"the\n",
"the\n",
"The brain is\n",
"Possible continuations:\n",
"0 : tumors\n",
"0 : the\n",
"0 : the\n",
"\n",
"> tensor([ 3, 76, 16])\n",
"cancer\n",
"most\n",
"the\n",
"The lung is\n",
"Possible continuations:\n",
"0 : cancer\n",
"0 : most\n",
"0 : the\n"
]
}
],
"source": [
"print(data)\n",
"count = 0\n",
"num_of_pred = 3\n",
"for d in data:\n",
" predictions = predict(d, num_of_pred)\n",
" print(input_batch[count])\n",
" print(\"Possible continuations:\")\n",
" for j in range(len(predictions)):\n",
" print(i, \": \", predictions[j])\n",
" count = count + 1\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}