|
|
@ -20,31 +20,37 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "ee9c1d92", |
|
|
|
"id": "806cfb27", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### Firstly try to import the data modules" |
|
|
|
"### link to \"Attention is All You Need\" paper describing transformer models" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 1, |
|
|
|
"id": "e444b44c", |
|
|
|
"execution_count": null, |
|
|
|
"id": "fe862072", |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"try:\n", |
|
|
|
" from Bio import Entrez, Medline \n", |
|
|
|
"except:\n", |
|
|
|
" !pip install Bio\n", |
|
|
|
" from Bio import Entrez, Medline \n" |
|
|
|
"https://arxiv.org/pdf/1706.03762.pdf" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "fa161b1b", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### load query data from text file " |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 2, |
|
|
|
"id": "e1912a79", |
|
|
|
"metadata": {}, |
|
|
|
"metadata": { |
|
|
|
"scrolled": false |
|
|
|
}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
@ -69,6 +75,28 @@ |
|
|
|
"!wget https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download -O pubmed-query.txt" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "da068411", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### import modules used for parsing query data" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"id": "c10bc5a8", |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"try:\n", |
|
|
|
" from Bio import Medline\n", |
|
|
|
"except: \n", |
|
|
|
" !pip install Bio\n", |
|
|
|
" from Bio import Medline" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "7bf15c30", |
|
|
@ -1545,46 +1573,6 @@ |
|
|
|
"## Now we can try to predict based on trained model" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "e685d3e1", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### define input batch " |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 157, |
|
|
|
"id": "cfb30fe0", |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"sample_batch = [\n", |
|
|
|
" \"Hello World\"\n", |
|
|
|
"]\n", |
|
|
|
"input_batch = sample_batch" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "10d51d39", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### define initial source mask for model" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 158, |
|
|
|
"id": "305853e8", |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"bptt = 2\n", |
|
|
|
"src_mask = generate_square_subsequent_mask(bptt).to(device)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "fe250072", |
|
|
@ -1635,10 +1623,22 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"execution_count": 7, |
|
|
|
"id": "6e2c35ba", |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"ename": "NameError", |
|
|
|
"evalue": "name 'torch' is not defined", |
|
|
|
"output_type": "error", |
|
|
|
"traceback": [ |
|
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
|
|
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", |
|
|
|
"Cell \u001b[0;32mIn [7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m device \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 2\u001b[0m device\n", |
|
|
|
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", |
|
|
|
"device" |
|
|
@ -1735,6 +1735,58 @@ |
|
|
|
" return predictions" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "a9b7311b", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### define input batch " |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 2, |
|
|
|
"id": "913628b4", |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"sample_batch = [\n", |
|
|
|
" \"There is\"\n", |
|
|
|
"]\n", |
|
|
|
"input_batch = sample_batch" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "45930a71", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### define initial source mask for model" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 3, |
|
|
|
"id": "c4bba2a1", |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"ename": "NameError", |
|
|
|
"evalue": "name 'generate_square_subsequent_mask' is not defined", |
|
|
|
"output_type": "error", |
|
|
|
"traceback": [ |
|
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
|
|
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", |
|
|
|
"Cell \u001b[0;32mIn [3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m bptt \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 2\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(bptt)\u001b[38;5;241m.\u001b[39mto(device)\n", |
|
|
|
"\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"bptt = 2\n", |
|
|
|
"src_mask = generate_square_subsequent_mask(bptt).to(device)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"id": "5b33b9f3", |
|
|
@ -1745,7 +1797,7 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 161, |
|
|
|
"execution_count": 4, |
|
|
|
"id": "b2895698", |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
@ -1780,95 +1832,20 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"execution_count": 5, |
|
|
|
"id": "13ed9298", |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : health\n", |
|
|
|
"2 : .\n", |
|
|
|
"3 : ,\n", |
|
|
|
"Hello World2\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World .\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : the\n", |
|
|
|
"2 : in\n", |
|
|
|
"3 : this\n", |
|
|
|
"Hello World .2\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : the\n", |
|
|
|
"2 : a\n", |
|
|
|
"3 : blood\n", |
|
|
|
"Hello World . in1\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in the\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : blood\n", |
|
|
|
"2 : effect\n", |
|
|
|
"3 : same\n", |
|
|
|
"Hello World . in the1\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in the blood\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : flow\n", |
|
|
|
"2 : pressure\n", |
|
|
|
"3 : vessels\n", |
|
|
|
"Hello World . in the blood2\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in the blood pressure\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : and\n", |
|
|
|
"2 : (\n", |
|
|
|
"3 : ,\n", |
|
|
|
"Hello World . in the blood pressure1\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in the blood pressure and\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : the\n", |
|
|
|
"2 : in\n", |
|
|
|
"3 : a\n", |
|
|
|
"Hello World . in the blood pressure and1\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in the blood pressure and the\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : blood\n", |
|
|
|
"2 : effect\n", |
|
|
|
"3 : same\n", |
|
|
|
"Hello World . in the blood pressure and the1\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in the blood pressure and the blood\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : flow\n", |
|
|
|
"2 : pressure\n", |
|
|
|
"3 : vessels\n", |
|
|
|
"Hello World . in the blood pressure and the blood3\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in the blood pressure and the blood vessels\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : .\n", |
|
|
|
"2 : of\n", |
|
|
|
"3 : ,\n", |
|
|
|
"Hello World . in the blood pressure and the blood vessels2\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in the blood pressure and the blood vessels of\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : the\n", |
|
|
|
"2 : blood\n", |
|
|
|
"3 : a\n", |
|
|
|
"Hello World . in the blood pressure and the blood vessels of1\n", |
|
|
|
"Text is now:\n", |
|
|
|
"Hello World . in the blood pressure and the blood vessels of the\n", |
|
|
|
"Possible continuations:\n", |
|
|
|
"1 : blood\n", |
|
|
|
"2 : effect\n", |
|
|
|
"3 : same\n" |
|
|
|
"ename": "NameError", |
|
|
|
"evalue": "name 'generate_square_subsequent_mask' is not defined", |
|
|
|
"output_type": "error", |
|
|
|
"traceback": [ |
|
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
|
|
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", |
|
|
|
"Cell \u001b[0;32mIn [5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mpredict_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n", |
|
|
|
"Cell \u001b[0;32mIn [4], line 7\u001b[0m, in \u001b[0;36mpredict_loop\u001b[0;34m(num_of_pred)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m(\u001b[38;5;129;01mnot\u001b[39;00m is_terminated):\n\u001b[1;32m 6\u001b[0m mask_size \u001b[38;5;241m=\u001b[39m bptt\u001b[38;5;241m+\u001b[39m(iteration) \n\u001b[0;32m----> 7\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(mask_size)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m toDataTensor(input_batch)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, d \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(data):\n", |
|
|
|
"\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|