From b80c9212812543c1f07779990d273c86ac3d7af4 Mon Sep 17 00:00:00 2001 From: Leonard Starke Date: Fri, 27 Jan 2023 23:17:00 +0100 Subject: [PATCH] add paper and remove unnecessary modules --- AutomaticSentenceCompletion.ipynb | 255 ++++++++++++++---------------- 1 file changed, 116 insertions(+), 139 deletions(-) diff --git a/AutomaticSentenceCompletion.ipynb b/AutomaticSentenceCompletion.ipynb index 852c59a..b23d5e2 100644 --- a/AutomaticSentenceCompletion.ipynb +++ b/AutomaticSentenceCompletion.ipynb @@ -20,31 +20,37 @@ }, { "cell_type": "markdown", - "id": "ee9c1d92", + "id": "806cfb27", "metadata": {}, "source": [ - "### Firstly try to import the data modules" + "### link to \"Attention is All You Need\" paper describing transformer models" ] }, { "cell_type": "code", - "execution_count": 1, - "id": "e444b44c", + "execution_count": null, + "id": "fe862072", "metadata": {}, "outputs": [], "source": [ - "try:\n", - " from Bio import Entrez, Medline \n", - "except:\n", - " !pip install Bio\n", - " from Bio import Entrez, Medline \n" + "https://arxiv.org/pdf/1706.03762.pdf" + ] + }, + { + "cell_type": "markdown", + "id": "fa161b1b", + "metadata": {}, + "source": [ + "### load query data from text file " ] }, { "cell_type": "code", "execution_count": 2, "id": "e1912a79", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { "name": "stdout", @@ -69,6 +75,28 @@ "!wget https://cloud.constantin-fuerst.com/s/944x5BpTQM7GjtF/download -O pubmed-query.txt" ] }, + { + "cell_type": "markdown", + "id": "da068411", + "metadata": {}, + "source": [ + "### import modules used for parsing query data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c10bc5a8", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from Bio import Medline\n", + "except: \n", + " !pip install Bio\n", + " from Bio import Medline" + ] + }, { "cell_type": "markdown", "id": "7bf15c30", @@ -115,7 +143,7 @@ "source": [ "max_records = 150000\n", "records = getPapers(\"pubmed-query.txt\")\n", - "records = records[:min(max_records,len(records))]\n", + "records = records[:min(max_records, len(records))]\n", "print(f\"Got {len(records)} records from the query text file\")" ] }, @@ -1545,46 +1573,6 @@ "## Now we can try to predict based on trained model" ] }, - { - "cell_type": "markdown", - "id": "e685d3e1", - "metadata": {}, - "source": [ - "### define input batch " - ] - }, - { - "cell_type": "code", - "execution_count": 157, - "id": "cfb30fe0", - "metadata": {}, - "outputs": [], - "source": [ - "sample_batch = [\n", - " \"Hello World\"\n", - "]\n", - "input_batch = sample_batch" - ] - }, - { - "cell_type": "markdown", - "id": "10d51d39", - "metadata": {}, - "source": [ - "### define initial source mask for model" - ] - }, - { - "cell_type": "code", - "execution_count": 158, - "id": "305853e8", - "metadata": {}, - "outputs": [], - "source": [ - "bptt = 2\n", - "src_mask = generate_square_subsequent_mask(bptt).to(device)" - ] - }, { "cell_type": "markdown", "id": "fe250072", @@ -1635,10 +1623,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "6e2c35ba", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'torch' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn [7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m device \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 2\u001b[0m device\n", + "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined" + ] + } + ], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device" @@ -1735,6 +1735,58 @@ " return predictions" ] }, + { + "cell_type": "markdown", + "id": "a9b7311b", + "metadata": {}, + "source": [ + "### define input batch " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "913628b4", + "metadata": {}, + "outputs": [], + "source": [ + "sample_batch = [\n", + " \"There is\"\n", + "]\n", + "input_batch = sample_batch" + ] + }, + { + "cell_type": "markdown", + "id": "45930a71", + "metadata": {}, + "source": [ + "### define initial source mask for model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c4bba2a1", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'generate_square_subsequent_mask' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn [3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m bptt \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 2\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(bptt)\u001b[38;5;241m.\u001b[39mto(device)\n", + "\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined" + ] + } + ], + "source": [ + "bptt = 2\n", + "src_mask = generate_square_subsequent_mask(bptt).to(device)" + ] + }, { "cell_type": "markdown", "id": "5b33b9f3", @@ -1745,7 +1797,7 @@ }, { "cell_type": "code", - "execution_count": 161, + "execution_count": 4, "id": "b2895698", "metadata": {}, "outputs": [], @@ -1780,95 +1832,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "13ed9298", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Possible continuations:\n", - "1 : health\n", - "2 : .\n", - "3 : ,\n", - "Hello World2\n", - "Text is now:\n", - "Hello World .\n", - "Possible continuations:\n", - "1 : the\n", - "2 : in\n", - "3 : this\n", - "Hello World .2\n", - "Text is now:\n", - "Hello World . in\n", - "Possible continuations:\n", - "1 : the\n", - "2 : a\n", - "3 : blood\n", - "Hello World . in1\n", - "Text is now:\n", - "Hello World . in the\n", - "Possible continuations:\n", - "1 : blood\n", - "2 : effect\n", - "3 : same\n", - "Hello World . in the1\n", - "Text is now:\n", - "Hello World . in the blood\n", - "Possible continuations:\n", - "1 : flow\n", - "2 : pressure\n", - "3 : vessels\n", - "Hello World . in the blood2\n", - "Text is now:\n", - "Hello World . in the blood pressure\n", - "Possible continuations:\n", - "1 : and\n", - "2 : (\n", - "3 : ,\n", - "Hello World . in the blood pressure1\n", - "Text is now:\n", - "Hello World . in the blood pressure and\n", - "Possible continuations:\n", - "1 : the\n", - "2 : in\n", - "3 : a\n", - "Hello World . in the blood pressure and1\n", - "Text is now:\n", - "Hello World . in the blood pressure and the\n", - "Possible continuations:\n", - "1 : blood\n", - "2 : effect\n", - "3 : same\n", - "Hello World . in the blood pressure and the1\n", - "Text is now:\n", - "Hello World . in the blood pressure and the blood\n", - "Possible continuations:\n", - "1 : flow\n", - "2 : pressure\n", - "3 : vessels\n", - "Hello World . in the blood pressure and the blood3\n", - "Text is now:\n", - "Hello World . in the blood pressure and the blood vessels\n", - "Possible continuations:\n", - "1 : .\n", - "2 : of\n", - "3 : ,\n", - "Hello World . in the blood pressure and the blood vessels2\n", - "Text is now:\n", - "Hello World . in the blood pressure and the blood vessels of\n", - "Possible continuations:\n", - "1 : the\n", - "2 : blood\n", - "3 : a\n", - "Hello World . in the blood pressure and the blood vessels of1\n", - "Text is now:\n", - "Hello World . in the blood pressure and the blood vessels of the\n", - "Possible continuations:\n", - "1 : blood\n", - "2 : effect\n", - "3 : same\n" + "ename": "NameError", + "evalue": "name 'generate_square_subsequent_mask' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn [5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mpredict_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn [4], line 7\u001b[0m, in \u001b[0;36mpredict_loop\u001b[0;34m(num_of_pred)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m(\u001b[38;5;129;01mnot\u001b[39;00m is_terminated):\n\u001b[1;32m 6\u001b[0m mask_size \u001b[38;5;241m=\u001b[39m bptt\u001b[38;5;241m+\u001b[39m(iteration) \n\u001b[0;32m----> 7\u001b[0m src_mask \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_square_subsequent_mask\u001b[49m(mask_size)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m toDataTensor(input_batch)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, d \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(data):\n", + "\u001b[0;31mNameError\u001b[0m: name 'generate_square_subsequent_mask' is not defined" ] } ],