|
@ -29,41 +29,10 @@ |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 2, |
|
|
|
|
|
|
|
|
"execution_count": 29, |
|
|
"id": "e444b44c", |
|
|
"id": "e444b44c", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"Defaulting to user installation because normal site-packages is not writeable\n", |
|
|
|
|
|
"Collecting Bio\n", |
|
|
|
|
|
" Downloading bio-1.4.0-py3-none-any.whl (270 kB)\n", |
|
|
|
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m270.9/270.9 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", |
|
|
|
|
|
"\u001b[?25hCollecting mygene\n", |
|
|
|
|
|
" Downloading mygene-3.2.2-py2.py3-none-any.whl (5.4 kB)\n", |
|
|
|
|
|
"Collecting biopython>=1.79\n", |
|
|
|
|
|
" Downloading biopython-1.79-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB)\n", |
|
|
|
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.7/2.7 MB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", |
|
|
|
|
|
"\u001b[?25hRequirement already satisfied: requests in /usr/lib/python3.10/site-packages (from Bio) (2.28.1)\n", |
|
|
|
|
|
"Collecting tqdm\n", |
|
|
|
|
|
" Downloading tqdm-4.64.1-py2.py3-none-any.whl (78 kB)\n", |
|
|
|
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.5/78.5 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n", |
|
|
|
|
|
"\u001b[?25hRequirement already satisfied: numpy in /usr/lib/python3.10/site-packages (from biopython>=1.79->Bio) (1.23.3)\n", |
|
|
|
|
|
"Collecting biothings-client>=0.2.6\n", |
|
|
|
|
|
" Downloading biothings_client-0.2.6-py2.py3-none-any.whl (37 kB)\n", |
|
|
|
|
|
"Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3.10/site-packages (from requests->Bio) (3.4)\n", |
|
|
|
|
|
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/lib/python3.10/site-packages (from requests->Bio) (1.26.12)\n", |
|
|
|
|
|
"Installing collected packages: tqdm, biopython, biothings-client, mygene, Bio\n", |
|
|
|
|
|
"\u001b[33m WARNING: The script tqdm is installed in '/home/hein/.local/bin' which is not on PATH.\n", |
|
|
|
|
|
" Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\u001b[33m\n", |
|
|
|
|
|
"\u001b[0m\u001b[33m WARNING: The scripts bio and fasta_filter.py are installed in '/home/hein/.local/bin' which is not on PATH.\n", |
|
|
|
|
|
" Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\u001b[33m\n", |
|
|
|
|
|
"\u001b[0mSuccessfully installed Bio-1.4.0 biopython-1.79 biothings-client-0.2.6 mygene-3.2.2 tqdm-4.64.1\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
|
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"try:\n", |
|
|
"try:\n", |
|
|
" from Bio import Entrez, Medline \n", |
|
|
" from Bio import Entrez, Medline \n", |
|
@ -82,12 +51,12 @@ |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 7, |
|
|
|
|
|
|
|
|
"execution_count": 30, |
|
|
"id": "adfb256a", |
|
|
"id": "adfb256a", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"def getPapers(myQuery, maxPapers, myEmail =\"xxxxx@xxxxxxxx.xx\"):\n", |
|
|
|
|
|
|
|
|
"def getPapers(myQuery, maxPapers, myEmail =\"leonard.starke@mailbox.tu-dresden.de\"):\n", |
|
|
" # Get articles from PubMed\n", |
|
|
" # Get articles from PubMed\n", |
|
|
" Entrez.email =myEmail\n", |
|
|
" Entrez.email =myEmail\n", |
|
|
" record =Entrez.read(Entrez.esearch(db=\"pubmed\", term=myQuery, retmax=maxPapers))\n", |
|
|
" record =Entrez.read(Entrez.esearch(db=\"pubmed\", term=myQuery, retmax=maxPapers))\n", |
|
@ -107,7 +76,7 @@ |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 9, |
|
|
|
|
|
|
|
|
"execution_count": 31, |
|
|
"id": "00481ec9", |
|
|
"id": "00481ec9", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [ |
|
|
"outputs": [ |
|
@ -116,16 +85,37 @@ |
|
|
"output_type": "stream", |
|
|
"output_type": "stream", |
|
|
"text": [ |
|
|
"text": [ |
|
|
"\n", |
|
|
"\n", |
|
|
"There are 1000 records for Cancer[tiab].\n" |
|
|
|
|
|
|
|
|
"There are 1600 records for Cancer[tiab].\n" |
|
|
] |
|
|
] |
|
|
} |
|
|
} |
|
|
], |
|
|
], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"myQuery =\"Cancer\"+\"[tiab]\" #query in title and abstract\n", |
|
|
"myQuery =\"Cancer\"+\"[tiab]\" #query in title and abstract\n", |
|
|
"maxPapers = 1000 # thinkabout outsourcing params to seperate section\n", |
|
|
|
|
|
|
|
|
"maxPapers = 1600 # thinkabout outsourcing params to seperate section\n", |
|
|
"records = getPapers(myQuery, maxPapers)\n" |
|
|
"records = getPapers(myQuery, maxPapers)\n" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 32, |
|
|
|
|
|
"id": "56cf72de", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"1600" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": 32, |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"output_type": "execute_result" |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"len(records)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"cell_type": "markdown", |
|
|
"id": "b67747c6", |
|
|
"id": "b67747c6", |
|
@ -136,14 +126,36 @@ |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 30, |
|
|
|
|
|
|
|
|
"execution_count": 33, |
|
|
"id": "dcf5c217", |
|
|
"id": "dcf5c217", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"r_abstracts = []\n", |
|
|
"r_abstracts = []\n", |
|
|
"for r in records:\n", |
|
|
"for r in records:\n", |
|
|
" r_abstracts.append(r)" |
|
|
|
|
|
|
|
|
" if not (r.get('AB') is None):\n", |
|
|
|
|
|
" r_abstracts.append(r['AB'])" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 34, |
|
|
|
|
|
"id": "eb1fb38b", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"1532" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": 34, |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"output_type": "execute_result" |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"len(r_abstracts)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
@ -156,7 +168,7 @@ |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 11, |
|
|
|
|
|
|
|
|
"execution_count": 35, |
|
|
"id": "c3199444", |
|
|
"id": "c3199444", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
@ -164,9 +176,775 @@ |
|
|
"try:\n", |
|
|
"try:\n", |
|
|
" import torch\n", |
|
|
" import torch\n", |
|
|
" from torch.utils.data import Dataset \n", |
|
|
" from torch.utils.data import Dataset \n", |
|
|
|
|
|
" from torchtext.data import get_tokenizer\n", |
|
|
|
|
|
"except:\n", |
|
|
|
|
|
" !pip --default-timeout=1000 install torch\n", |
|
|
|
|
|
" !pip --default-timeout=1000 install torchtext\n", |
|
|
|
|
|
" import torch\n", |
|
|
|
|
|
" from torch.utils.data import Dataset \n", |
|
|
|
|
|
" from torchtext.data import get_tokenizer" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "5b4007e8", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### Import numpy" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 36, |
|
|
|
|
|
"id": "daca9db6", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"try:\n", |
|
|
|
|
|
" import numpy as np\n", |
|
|
"except:\n", |
|
|
"except:\n", |
|
|
" !pip install pytorch\n", |
|
|
|
|
|
" \n" |
|
|
|
|
|
|
|
|
" !pip install numpy\n", |
|
|
|
|
|
" import numpy as np\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "4df1e449", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define token iterators" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 37, |
|
|
|
|
|
"id": "8a128d3c", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def train_abstract_iter():\n", |
|
|
|
|
|
" for abstract in r_abstracts[:1000]:\n", |
|
|
|
|
|
" yield abstract" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 38, |
|
|
|
|
|
"id": "97e89986", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def val_abstract_iter():\n", |
|
|
|
|
|
" for abstract in r_abstracts[1001:1300]:\n", |
|
|
|
|
|
" yield abstract" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 39, |
|
|
|
|
|
"id": "0d6e89c4", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def test_abstract_iter():\n", |
|
|
|
|
|
" for abstract in r_abstracts[1301:1542]:\n", |
|
|
|
|
|
" yield abstract" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "e5e9c5a2", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define Tokenize function" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 40, |
|
|
|
|
|
"id": "0bdbc40a", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"tokenizer = get_tokenizer(\"basic_english\")\n", |
|
|
|
|
|
"def tokenize_abstract_iter():\n", |
|
|
|
|
|
" for abstract in r_abstracts:\n", |
|
|
|
|
|
" yield tokenizer(abstract)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "37da40bb", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### Map every world to a id to store inside torch tensor" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 41, |
|
|
|
|
|
"id": "a438ab1f", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"from torchtext.vocab import build_vocab_from_iterator\n", |
|
|
|
|
|
"token_generator = tokenize_abstract_iter()\n", |
|
|
|
|
|
"vocab = build_vocab_from_iterator(token_generator, specials=['<unk>'])\n", |
|
|
|
|
|
"vocab.set_default_index(vocab['<unk>'])\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "221bdc48", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### now convert to tensor\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 42, |
|
|
|
|
|
"id": "0e5bc361", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def data_process(tokens_iter):\n", |
|
|
|
|
|
" \"\"\"Converts raw text into a flat Tensor.\"\"\"\n", |
|
|
|
|
|
" data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in tokens_iter]\n", |
|
|
|
|
|
" return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 43, |
|
|
|
|
|
"id": "dfd7400d", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"train_generator = train_abstract_iter()\n", |
|
|
|
|
|
"val_generator = val_abstract_iter()\n", |
|
|
|
|
|
"test_generator = test_abstract_iter()\n", |
|
|
|
|
|
"train_data = data_process(train_generator)\n", |
|
|
|
|
|
"val_data = data_process(val_generator)\n", |
|
|
|
|
|
"test_data = data_process(test_generator)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "c49a2734", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### check gpu" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 44, |
|
|
|
|
|
"id": "c155ee31", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 45, |
|
|
|
|
|
"id": "79b2d248", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"device(type='cuda')" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": 45, |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"output_type": "execute_result" |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"device" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "2150ba71", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define model" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 46, |
|
|
|
|
|
"id": "a33d722f", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"import math\n", |
|
|
|
|
|
"from typing import Tuple\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"import torch\n", |
|
|
|
|
|
"from torch import nn, Tensor\n", |
|
|
|
|
|
"import torch.nn.functional as F\n", |
|
|
|
|
|
"from torch.nn import TransformerEncoder, TransformerEncoderLayer\n", |
|
|
|
|
|
"from torch.utils.data import dataset\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"class TransformerModel(nn.Module):\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,\n", |
|
|
|
|
|
" nlayers: int, dropout: float = 0.5):\n", |
|
|
|
|
|
" super().__init__()\n", |
|
|
|
|
|
" self.model_type = 'Transformer'\n", |
|
|
|
|
|
" self.pos_encoder = PositionalEncoding(d_model, dropout)\n", |
|
|
|
|
|
" encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)\n", |
|
|
|
|
|
" self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n", |
|
|
|
|
|
" self.encoder = nn.Embedding(ntoken, d_model)\n", |
|
|
|
|
|
" self.d_model = d_model\n", |
|
|
|
|
|
" self.decoder = nn.Linear(d_model, ntoken)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" self.init_weights()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def init_weights(self) -> None:\n", |
|
|
|
|
|
" initrange = 0.1\n", |
|
|
|
|
|
" self.encoder.weight.data.uniform_(-initrange, initrange)\n", |
|
|
|
|
|
" self.decoder.bias.data.zero_()\n", |
|
|
|
|
|
" self.decoder.weight.data.uniform_(-initrange, initrange)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:\n", |
|
|
|
|
|
" \"\"\"\n", |
|
|
|
|
|
" Args:\n", |
|
|
|
|
|
" src: Tensor, shape [seq_len, batch_size]\n", |
|
|
|
|
|
" src_mask: Tensor, shape [seq_len, seq_len]\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" Returns:\n", |
|
|
|
|
|
" output Tensor of shape [seq_len, batch_size, ntoken]\n", |
|
|
|
|
|
" \"\"\"\n", |
|
|
|
|
|
" src = self.encoder(src) * math.sqrt(self.d_model)\n", |
|
|
|
|
|
" src = self.pos_encoder(src)\n", |
|
|
|
|
|
" output = self.transformer_encoder(src, src_mask)\n", |
|
|
|
|
|
" output = self.decoder(output)\n", |
|
|
|
|
|
" return output\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"def generate_square_subsequent_mask(sz: int) -> Tensor:\n", |
|
|
|
|
|
" \"\"\"Generates an upper-triangular matrix of -inf, with zeros on diag.\"\"\"\n", |
|
|
|
|
|
" return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"id": "da8fb12b", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 47, |
|
|
|
|
|
"id": "c2f6d33b", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"class PositionalEncoding(nn.Module):\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):\n", |
|
|
|
|
|
" super().__init__()\n", |
|
|
|
|
|
" self.dropout = nn.Dropout(p=dropout)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" position = torch.arange(max_len).unsqueeze(1)\n", |
|
|
|
|
|
" div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n", |
|
|
|
|
|
" pe = torch.zeros(max_len, 1, d_model)\n", |
|
|
|
|
|
" pe[:, 0, 0::2] = torch.sin(position * div_term)\n", |
|
|
|
|
|
" pe[:, 0, 1::2] = torch.cos(position * div_term)\n", |
|
|
|
|
|
" self.register_buffer('pe', pe)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def forward(self, x: Tensor) -> Tensor:\n", |
|
|
|
|
|
" \"\"\"\n", |
|
|
|
|
|
" Args:\n", |
|
|
|
|
|
" x: Tensor, shape [seq_len, batch_size, embedding_dim]\n", |
|
|
|
|
|
" \"\"\"\n", |
|
|
|
|
|
" x = x + self.pe[:x.size(0)]\n", |
|
|
|
|
|
" return self.dropout(x)\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 48, |
|
|
|
|
|
"id": "9e184841", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def batchify(data: Tensor, bsz: int) -> Tensor:\n", |
|
|
|
|
|
" \"\"\"Divides the data into bsz separate sequences, removing extra elements\n", |
|
|
|
|
|
" that wouldn't cleanly fit.\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" Args:\n", |
|
|
|
|
|
" data: Tensor, shape [N]\n", |
|
|
|
|
|
" bsz: int, batch size\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" Returns:\n", |
|
|
|
|
|
" Tensor of shape [N // bsz, bsz]\n", |
|
|
|
|
|
" \"\"\"\n", |
|
|
|
|
|
" seq_len = data.size(0) // bsz\n", |
|
|
|
|
|
" data = data[:seq_len * bsz]\n", |
|
|
|
|
|
" data = data.view(bsz, seq_len).t().contiguous()\n", |
|
|
|
|
|
" return data.to(device)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 49, |
|
|
|
|
|
"id": "a4def1ac", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"batch_size = 20\n", |
|
|
|
|
|
"eval_batch_size = 10\n", |
|
|
|
|
|
"train_data = batchify(train_data, batch_size) # shape [seq_len, batch_size]\n", |
|
|
|
|
|
"val_data = batchify(val_data, eval_batch_size)\n", |
|
|
|
|
|
"test_data = batchify(test_data, eval_batch_size)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 50, |
|
|
|
|
|
"id": "4ab5b8fd", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"bptt = 35\n", |
|
|
|
|
|
"def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:\n", |
|
|
|
|
|
" \"\"\"\n", |
|
|
|
|
|
" Args:\n", |
|
|
|
|
|
" source: Tensor, shape [full_seq_len, batch_size]\n", |
|
|
|
|
|
" i: int\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" Returns:\n", |
|
|
|
|
|
" tuple (data, target), where data has shape [seq_len, batch_size] and\n", |
|
|
|
|
|
" target has shape [seq_len * batch_size]\n", |
|
|
|
|
|
" \"\"\"\n", |
|
|
|
|
|
" seq_len = min(bptt, len(source) - 1 - i)\n", |
|
|
|
|
|
" data = source[i:i+seq_len]\n", |
|
|
|
|
|
" target = source[i+1:i+1+seq_len].reshape(-1)\n", |
|
|
|
|
|
" return data, target" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 51, |
|
|
|
|
|
"id": "c53764da", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"ntokens = len(vocab) # size of vocabulary\n", |
|
|
|
|
|
"emsize = 200 # embedding dimension\n", |
|
|
|
|
|
"d_hid = 200 # dimension of the feedforward network model in nn.TransformerEncoder\n", |
|
|
|
|
|
"nlayers = 2 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder\n", |
|
|
|
|
|
"nhead = 2 # number of heads in nn.MultiheadAttention\n", |
|
|
|
|
|
"dropout = 0.2 # dropout probability\n", |
|
|
|
|
|
"model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 52, |
|
|
|
|
|
"id": "50ab3fb6", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"import copy\n", |
|
|
|
|
|
"import time\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"criterion = nn.CrossEntropyLoss()\n", |
|
|
|
|
|
"lr = 5.0 # learning rate\n", |
|
|
|
|
|
"optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", |
|
|
|
|
|
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"def train(model: nn.Module) -> None:\n", |
|
|
|
|
|
" model.train() # turn on train mode\n", |
|
|
|
|
|
" total_loss = 0.\n", |
|
|
|
|
|
" log_interval = 200\n", |
|
|
|
|
|
" start_time = time.time()\n", |
|
|
|
|
|
" src_mask = generate_square_subsequent_mask(bptt).to(device)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" num_batches = len(train_data) // bptt\n", |
|
|
|
|
|
" for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n", |
|
|
|
|
|
" data, targets = get_batch(train_data, i)\n", |
|
|
|
|
|
" seq_len = data.size(0)\n", |
|
|
|
|
|
" if seq_len != bptt: # only on last batch\n", |
|
|
|
|
|
" src_mask = src_mask[:seq_len, :seq_len]\n", |
|
|
|
|
|
" output = model(data, src_mask)\n", |
|
|
|
|
|
" loss = criterion(output.view(-1, ntokens), targets)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" optimizer.zero_grad()\n", |
|
|
|
|
|
" loss.backward()\n", |
|
|
|
|
|
" torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n", |
|
|
|
|
|
" optimizer.step()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" total_loss += loss.item()\n", |
|
|
|
|
|
" if batch % log_interval == 0 and batch > 0:\n", |
|
|
|
|
|
" lr = scheduler.get_last_lr()[0]\n", |
|
|
|
|
|
" ms_per_batch = (time.time() - start_time) * 1000 / log_interval\n", |
|
|
|
|
|
" cur_loss = total_loss / log_interval\n", |
|
|
|
|
|
" ppl = math.exp(cur_loss)\n", |
|
|
|
|
|
" print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '\n", |
|
|
|
|
|
" f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n", |
|
|
|
|
|
" f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n", |
|
|
|
|
|
" total_loss = 0\n", |
|
|
|
|
|
" start_time = time.time()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n", |
|
|
|
|
|
" model.eval() # turn on evaluation mode\n", |
|
|
|
|
|
" total_loss = 0.\n", |
|
|
|
|
|
" src_mask = generate_square_subsequent_mask(bptt).to(device)\n", |
|
|
|
|
|
" with torch.no_grad():\n", |
|
|
|
|
|
" for i in range(0, eval_data.size(0) - 1, bptt):\n", |
|
|
|
|
|
" data, targets = get_batch(eval_data, i)\n", |
|
|
|
|
|
" seq_len = data.size(0)\n", |
|
|
|
|
|
" if seq_len != bptt:\n", |
|
|
|
|
|
" src_mask = src_mask[:seq_len, :seq_len]\n", |
|
|
|
|
|
" output = model(data, src_mask)\n", |
|
|
|
|
|
" output_flat = output.view(-1, ntokens)\n", |
|
|
|
|
|
" total_loss += seq_len * criterion(output_flat, targets).item()\n", |
|
|
|
|
|
" return total_loss / (len(eval_data) - 1)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 53, |
|
|
|
|
|
"id": "09c4d4ce", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"| epoch 1 | 200/ 383 batches | lr 5.00 | ms/batch 63.06 | loss 8.09 | ppl 3258.10\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 1 | time: 17.05s | valid loss 6.34 | valid ppl 566.15\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| epoch 2 | 200/ 383 batches | lr 4.75 | ms/batch 19.83 | loss 6.14 | ppl 463.13\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 2 | time: 8.38s | valid loss 6.01 | valid ppl 406.56\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| epoch 3 | 200/ 383 batches | lr 4.51 | ms/batch 19.83 | loss 5.61 | ppl 273.67\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 3 | time: 8.38s | valid loss 5.95 | valid ppl 383.10\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| epoch 4 | 200/ 383 batches | lr 4.29 | ms/batch 19.89 | loss 5.25 | ppl 190.90\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 4 | time: 8.40s | valid loss 5.96 | valid ppl 386.38\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| epoch 5 | 200/ 383 batches | lr 4.07 | ms/batch 19.88 | loss 4.96 | ppl 142.55\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 5 | time: 8.40s | valid loss 5.99 | valid ppl 398.76\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| epoch 6 | 200/ 383 batches | lr 3.87 | ms/batch 19.89 | loss 4.71 | ppl 111.09\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 6 | time: 8.40s | valid loss 6.04 | valid ppl 421.64\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| epoch 7 | 200/ 383 batches | lr 3.68 | ms/batch 19.89 | loss 4.49 | ppl 89.44\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 7 | time: 8.40s | valid loss 6.11 | valid ppl 452.51\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| epoch 8 | 200/ 383 batches | lr 3.49 | ms/batch 19.92 | loss 4.30 | ppl 73.72\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 8 | time: 8.42s | valid loss 6.17 | valid ppl 479.04\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| epoch 9 | 200/ 383 batches | lr 3.32 | ms/batch 19.93 | loss 4.13 | ppl 62.43\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 9 | time: 8.42s | valid loss 6.26 | valid ppl 522.27\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| epoch 10 | 200/ 383 batches | lr 3.15 | ms/batch 19.95 | loss 3.99 | ppl 53.96\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n", |
|
|
|
|
|
"| end of epoch 10 | time: 8.43s | valid loss 6.31 | valid ppl 548.35\n", |
|
|
|
|
|
"-----------------------------------------------------------------------------------------\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"best_val_loss = float('inf')\n", |
|
|
|
|
|
"epochs = 10\n", |
|
|
|
|
|
"best_model = None\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"for epoch in range(1, epochs + 1):\n", |
|
|
|
|
|
" epoch_start_time = time.time()\n", |
|
|
|
|
|
" train(model)\n", |
|
|
|
|
|
" val_loss = evaluate(model, val_data)\n", |
|
|
|
|
|
" val_ppl = math.exp(val_loss)\n", |
|
|
|
|
|
" elapsed = time.time() - epoch_start_time\n", |
|
|
|
|
|
" print('-' * 89)\n", |
|
|
|
|
|
" print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '\n", |
|
|
|
|
|
" f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')\n", |
|
|
|
|
|
" print('-' * 89)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" if val_loss < best_val_loss:\n", |
|
|
|
|
|
" best_val_loss = val_loss\n", |
|
|
|
|
|
" best_model = copy.deepcopy(model)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" scheduler.step()" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 54, |
|
|
|
|
|
"id": "12fdd0aa", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"=========================================================================================\n", |
|
|
|
|
|
"| End of training | test loss 5.80 | test ppl 329.59\n", |
|
|
|
|
|
"=========================================================================================\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"test_loss = evaluate(best_model, test_data)\n", |
|
|
|
|
|
"test_ppl = math.exp(test_loss)\n", |
|
|
|
|
|
"print('=' * 89)\n", |
|
|
|
|
|
"print(f'| End of training | test loss {test_loss:5.2f} | '\n", |
|
|
|
|
|
" f'test ppl {test_ppl:8.2f}')\n", |
|
|
|
|
|
"print('=' * 89)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "e685d3e1", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define input batch " |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 274, |
|
|
|
|
|
"id": "cfb30fe0", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"input_batch = [\n", |
|
|
|
|
|
" \"The brain is\",\n", |
|
|
|
|
|
" \"The lung is\"\n", |
|
|
|
|
|
"]" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 275, |
|
|
|
|
|
"id": "305853e8", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"bptt = 3\n", |
|
|
|
|
|
"src_mask = generate_square_subsequent_mask(bptt).to(device)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 276, |
|
|
|
|
|
"id": "afe585d6", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def predict_abstract_iter():\n", |
|
|
|
|
|
" for batch in input_batch:\n", |
|
|
|
|
|
" yield tokenizer(batch)\n", |
|
|
|
|
|
"predict_generator = predict_abstract_iter()" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"id": "f7ac6188", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 278, |
|
|
|
|
|
"id": "0788b045", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"data = [torch.tensor(vocab.lookup_indices(item)) for item in predict_generator]" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 279, |
|
|
|
|
|
"id": "8bfaa8bd", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": 279, |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"output_type": "execute_result" |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"data" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 280, |
|
|
|
|
|
"id": "dd0e7310", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"device(type='cuda')" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": 280, |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"output_type": "execute_result" |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", |
|
|
|
|
|
"device" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 281, |
|
|
|
|
|
"id": "1728f0fd", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"for d in data:\n", |
|
|
|
|
|
" d.to(device)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 282, |
|
|
|
|
|
"id": "49d27864", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"best_model.eval()\n", |
|
|
|
|
|
"for batch in data:\n", |
|
|
|
|
|
" output = best_model(batch.to(device), src_mask)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 283, |
|
|
|
|
|
"id": "a3404169", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"result_np = []\n", |
|
|
|
|
|
"pred_np = output.cpu().detach().numpy()\n", |
|
|
|
|
|
"for el in pred_np:\n", |
|
|
|
|
|
" result_np.append(el)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 284, |
|
|
|
|
|
"id": "c7064c0c", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"[array([[-0.5258564 , 4.6108465 , 4.8358154 , ..., -0.46871045,\n", |
|
|
|
|
|
" -0.04386039, -0.13068362],\n", |
|
|
|
|
|
" [-0.35341978, 8.821181 , 9.57295 , ..., -1.2820313 ,\n", |
|
|
|
|
|
" -0.989242 , -0.15542248],\n", |
|
|
|
|
|
" [-0.39717233, 5.7111125 , 6.4295497 , ..., -0.27339834,\n", |
|
|
|
|
|
" -1.5333815 , 0.16042188]], dtype=float32),\n", |
|
|
|
|
|
" array([[-0.5291481 , 4.6452312 , 4.7958803 , ..., -0.4642661 ,\n", |
|
|
|
|
|
" -0.04427804, -0.12225106],\n", |
|
|
|
|
|
" [-0.35347146, 8.824585 , 9.5098095 , ..., -1.2693769 ,\n", |
|
|
|
|
|
" -0.97772634, -0.13521233],\n", |
|
|
|
|
|
" [-0.39733842, 5.693817 , 6.368334 , ..., -0.26423275,\n", |
|
|
|
|
|
" -1.527182 , 0.16518843]], dtype=float32),\n", |
|
|
|
|
|
" array([[-0.53349733, 4.6777644 , 4.7953978 , ..., -0.4346264 ,\n", |
|
|
|
|
|
" -0.03433151, -0.11583059],\n", |
|
|
|
|
|
" [-0.3695631 , 8.82613 , 9.477964 , ..., -1.2404828 ,\n", |
|
|
|
|
|
" -0.9594678 , -0.11550146],\n", |
|
|
|
|
|
" [-0.41203997, 5.699034 , 6.341655 , ..., -0.24370295,\n", |
|
|
|
|
|
" -1.5179048 , 0.16230991]], dtype=float32)]" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": 284, |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"output_type": "execute_result" |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"result_np" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 285, |
|
|
|
|
|
"id": "679e2316", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def predict(input_line, n_predictions=1):\n", |
|
|
|
|
|
" print('\\n> %s' % input_line)\n", |
|
|
|
|
|
" with torch.no_grad():\n", |
|
|
|
|
|
" output = best_model(input_line.to(device), src_mask)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" # Get top N categories\n", |
|
|
|
|
|
" topv, topi = output.topk(n_predictions, 1, True)\n", |
|
|
|
|
|
" #x, y = output.topk(n_predictions, 1, True)\n", |
|
|
|
|
|
" #print(x.shape)\n", |
|
|
|
|
|
" #print(topv.shape)\n", |
|
|
|
|
|
" # print(topi.shape)\n", |
|
|
|
|
|
" predictions = []\n", |
|
|
|
|
|
" for i in range(n_predictions):\n", |
|
|
|
|
|
" value = topv[0][i]\n", |
|
|
|
|
|
" v1, v2 = value.topk(1)\n", |
|
|
|
|
|
" predict_token_index = v2.cpu().detach().numpy()\n", |
|
|
|
|
|
" print(vocab.lookup_token(predict_token_index))\n", |
|
|
|
|
|
" #print(category_index)\n", |
|
|
|
|
|
" #print('(%.2f) %s' % (value, all_categories[category_index]))\n", |
|
|
|
|
|
" #predictions.append([value, all_categories[category_index]])" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 286, |
|
|
|
|
|
"id": "03389137", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"\n", |
|
|
|
|
|
"> tensor([ 3, 555, 16])\n", |
|
|
|
|
|
"tumors\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"> tensor([ 3, 76, 16])\n", |
|
|
|
|
|
"cancer\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"for d in data:\n", |
|
|
|
|
|
" predict(d)" |
|
|
] |
|
|
] |
|
|
} |
|
|
} |
|
|
], |
|
|
], |
|
|