Browse Source

add paper and remove unnecessary modules

dev_neuralnet
Leonard Starke 2 years ago
parent
commit
b80c921281
  1. 255
      AutomaticSentenceCompletion.ipynb

255
AutomaticSentenceCompletion.ipynb

@ -20,31 +20,37 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "ee9c1d92",
"id": "806cfb27",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Firstly try to import the data modules"
"### link to \"Attention is All You Need\" paper describing transformer models"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1,
"id": "e444b44c",
"execution_count": null,
"id": "fe862072",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"try:\n",
" from Bio import Entrez, Medline \n",
"except:\n",
" !pip install Bio\n",
" from Bio import Entrez, Medline \n"
"https://arxiv.org/pdf/1706.03762.pdf"
]
},
{
"cell_type": "markdown",
"id": "fa161b1b",
"metadata": {},
"source": [
"### load query data from text file "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"id": "e1912a79", "id": "e1912a79",
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -69,6 +75,28 @@
"!wget https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download -O pubmed-query.txt" "!wget https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download -O pubmed-query.txt"
] ]
}, },
{
"cell_type": "markdown",
"id": "da068411",
"metadata": {},
"source": [
"### import modules used for parsing query data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c10bc5a8",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" from Bio import Medline\n",
"except: \n",
" !pip install Bio\n",
" from Bio import Medline"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "7bf15c30", "id": "7bf15c30",
@ -115,7 +143,7 @@
"source": [ "source": [
"max_records = 150000\n", "max_records = 150000\n",
"records = getPapers(\"pubmed-query.txt\")\n", "records = getPapers(\"pubmed-query.txt\")\n",
"records = records[:min(max_records,len(records))]\n",
"records = records[:min(max_records, len(records))]\n",
"print(f\"Got {len(records)} records from the query text file\")" "print(f\"Got {len(records)} records from the query text file\")"
] ]
}, },
@ -1545,46 +1573,6 @@
"## Now we can try to predict based on trained model" "## Now we can try to predict based on trained model"
] ]
}, },
{
"cell_type": "markdown",
"id": "e685d3e1",
"metadata": {},
"source": [
"### define input batch "
]
},
{
"cell_type": "code",
"execution_count": 157,
"id": "cfb30fe0",
"metadata": {},
"outputs": [],
"source": [
"sample_batch = [\n",
" \"Hello World\"\n",
"]\n",
"input_batch = sample_batch"
]
},
{
"cell_type": "markdown",
"id": "10d51d39",
"metadata": {},
"source": [
"### define initial source mask for model"
]
},
{
"cell_type": "code",
"execution_count": 158,
"id": "305853e8",
"metadata": {},
"outputs": [],
"source": [
"bptt = 2\n",
"src_mask = generate_square_subsequent_mask(bptt).to(device)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "fe250072", "id": "fe250072",
@ -1635,10 +1623,22 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "6e2c35ba", "id": "6e2c35ba",
"metadata": {}, "metadata": {},
"outputs": [],
"outputs": [
{
"ename": "NameError",
"evalue": "name 'torch' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m device \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 2\u001b[0m device\n",
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
]
}
],
"source": [ "source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"device" "device"
@ -1735,6 +1735,58 @@
" return predictions" " return predictions"
] ]
}, },
{
"cell_type": "markdown",
"id": "a9b7311b",
"metadata": {},
"source": [
"### define input batch "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "913628b4",
"metadata": {},
"outputs": [],
"source": [
"sample_batch = [\n",
" \"There is\"\n",
"]\n",
"input_batch = sample_batch"
]
},
{
"cell_type": "markdown",
"id": "45930a71",
"metadata": {},
"source": [
"### define initial source mask for model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c4bba2a1",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'generate_square_subsequent_mask' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m bptt \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 2\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(bptt)\u001b[38;5;241m.\u001b[39mto(device)\n",
"\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined"
]
}
],
"source": [
"bptt = 2\n",
"src_mask = generate_square_subsequent_mask(bptt).to(device)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "5b33b9f3", "id": "5b33b9f3",
@ -1745,7 +1797,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 161,
"execution_count": 4,
"id": "b2895698", "id": "b2895698",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -1780,95 +1832,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "13ed9298", "id": "13ed9298",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout",
"output_type": "stream",
"text": [
"Possible continuations:\n",
"1 : health\n",
"2 : .\n",
"3 : ,\n",
"Hello World2\n",
"Text is now:\n",
"Hello World .\n",
"Possible continuations:\n",
"1 : the\n",
"2 : in\n",
"3 : this\n",
"Hello World .2\n",
"Text is now:\n",
"Hello World . in\n",
"Possible continuations:\n",
"1 : the\n",
"2 : a\n",
"3 : blood\n",
"Hello World . in1\n",
"Text is now:\n",
"Hello World . in the\n",
"Possible continuations:\n",
"1 : blood\n",
"2 : effect\n",
"3 : same\n",
"Hello World . in the1\n",
"Text is now:\n",
"Hello World . in the blood\n",
"Possible continuations:\n",
"1 : flow\n",
"2 : pressure\n",
"3 : vessels\n",
"Hello World . in the blood2\n",
"Text is now:\n",
"Hello World . in the blood pressure\n",
"Possible continuations:\n",
"1 : and\n",
"2 : (\n",
"3 : ,\n",
"Hello World . in the blood pressure1\n",
"Text is now:\n",
"Hello World . in the blood pressure and\n",
"Possible continuations:\n",
"1 : the\n",
"2 : in\n",
"3 : a\n",
"Hello World . in the blood pressure and1\n",
"Text is now:\n",
"Hello World . in the blood pressure and the\n",
"Possible continuations:\n",
"1 : blood\n",
"2 : effect\n",
"3 : same\n",
"Hello World . in the blood pressure and the1\n",
"Text is now:\n",
"Hello World . in the blood pressure and the blood\n",
"Possible continuations:\n",
"1 : flow\n",
"2 : pressure\n",
"3 : vessels\n",
"Hello World . in the blood pressure and the blood3\n",
"Text is now:\n",
"Hello World . in the blood pressure and the blood vessels\n",
"Possible continuations:\n",
"1 : .\n",
"2 : of\n",
"3 : ,\n",
"Hello World . in the blood pressure and the blood vessels2\n",
"Text is now:\n",
"Hello World . in the blood pressure and the blood vessels of\n",
"Possible continuations:\n",
"1 : the\n",
"2 : blood\n",
"3 : a\n",
"Hello World . in the blood pressure and the blood vessels of1\n",
"Text is now:\n",
"Hello World . in the blood pressure and the blood vessels of the\n",
"Possible continuations:\n",
"1 : blood\n",
"2 : effect\n",
"3 : same\n"
"ename": "NameError",
"evalue": "name 'generate_square_subsequent_mask' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mpredict_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn [4], line 7\u001b[0m, in \u001b[0;36mpredict_loop\u001b[0;34m(num_of_pred)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m(\u001b[38;5;129;01mnot\u001b[39;00m is_terminated):\n\u001b[1;32m 6\u001b[0m mask_size \u001b[38;5;241m=\u001b[39m bptt\u001b[38;5;241m+\u001b[39m(iteration) \n\u001b[0;32m----> 7\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(mask_size)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m toDataTensor(input_batch)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, d \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(data):\n",
"\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined"
] ]
} }
], ],

Loading…
Cancel
Save