diff --git a/AutomaticSentenceCompletion.ipynb b/AutomaticSentenceCompletion.ipynb
index 1efb201..0e977f5 100644
--- a/AutomaticSentenceCompletion.ipynb
+++ b/AutomaticSentenceCompletion.ipynb
@@ -44,7 +44,7 @@
   {
    "cell_type": "code",
    "execution_count": 2,
-   "id": "3209935b",
+   "id": "293027a6",
    "metadata": {},
    "outputs": [
     {
@@ -101,7 +101,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 9,
    "id": "00481ec9",
    "metadata": {},
    "outputs": [
@@ -128,7 +128,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 10,
    "id": "dcf5c217",
    "metadata": {},
    "outputs": [],
@@ -149,10 +149,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": 11,
    "id": "c3199444",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/hein/.local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
    "source": [
     "try:\n",
     "    import torch\n",
@@ -176,7 +185,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 12,
    "id": "daca9db6",
    "metadata": {},
    "outputs": [],
@@ -198,7 +207,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 33,
+   "execution_count": 13,
    "id": "8d2312db",
    "metadata": {},
    "outputs": [],
@@ -216,7 +225,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 34,
+   "execution_count": 14,
    "id": "3f23404d",
    "metadata": {},
    "outputs": [],
@@ -228,7 +237,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 15,
    "id": "8a128d3c",
    "metadata": {},
    "outputs": [],
@@ -240,7 +249,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 36,
+   "execution_count": 16,
    "id": "97e89986",
    "metadata": {},
    "outputs": [],
@@ -252,7 +261,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": 17,
    "id": "0d6e89c4",
    "metadata": {},
    "outputs": [],
@@ -272,7 +281,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 38,
+   "execution_count": 18,
    "id": "0bdbc40a",
    "metadata": {},
    "outputs": [],
@@ -293,7 +302,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 39,
+   "execution_count": 19,
    "id": "a438ab1f",
    "metadata": {},
    "outputs": [],
@@ -314,7 +323,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 40,
+   "execution_count": 20,
    "id": "0e5bc361",
    "metadata": {},
    "outputs": [],
@@ -327,7 +336,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 41,
+   "execution_count": 21,
    "id": "dfd7400d",
    "metadata": {},
    "outputs": [],
@@ -350,7 +359,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 42,
+   "execution_count": 22,
    "id": "c155ee31",
    "metadata": {},
    "outputs": [],
@@ -389,7 +398,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 46,
+   "execution_count": 23,
    "id": "a33d722f",
    "metadata": {},
    "outputs": [],
@@ -453,7 +462,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 47,
+   "execution_count": 24,
    "id": "c2f6d33b",
    "metadata": {},
    "outputs": [],
@@ -490,7 +499,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 48,
+   "execution_count": 25,
    "id": "9e184841",
    "metadata": {},
    "outputs": [],
@@ -514,7 +523,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 49,
+   "execution_count": 26,
    "id": "a4def1ac",
    "metadata": {},
    "outputs": [],
@@ -536,7 +545,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 50,
+   "execution_count": 27,
    "id": "4ab5b8fd",
    "metadata": {},
    "outputs": [],
@@ -568,7 +577,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 51,
+   "execution_count": 28,
    "id": "c53764da",
    "metadata": {},
    "outputs": [],
@@ -592,7 +601,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 52,
+   "execution_count": 29,
    "id": "ddaa1d64",
    "metadata": {},
    "outputs": [],
@@ -616,7 +625,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 54,
+   "execution_count": 30,
    "id": "50ab3fb6",
    "metadata": {},
    "outputs": [],
@@ -665,7 +674,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 55,
+   "execution_count": 31,
    "id": "3d179bb0",
    "metadata": {},
    "outputs": [],
@@ -1736,7 +1745,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "torch.save(best_model.state_dict(), \"autocomplete_model\")"
+    "torch.save(best_model, \"pubmed-sentencecomplete.pt\")"
    ]
   },
   {
@@ -1757,14 +1766,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 59,
+   "execution_count": 32,
    "id": "cfb30fe0",
    "metadata": {},
    "outputs": [],
    "source": [
     "sample_batch = [\n",
     "    \"The brain is\",\n",
-    "    \"The lung is\"\n",
     "]\n",
     "input_batch = sample_batch"
    ]
   },
   {
@@ -1779,7 +1787,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 60,
+   "execution_count": 33,
    "id": "305853e8",
    "metadata": {},
    "outputs": [],
@@ -1793,18 +1801,18 @@
    "id": "fe250072",
    "metadata": {},
    "source": [
-    "### define iterator for predict batch "
+    "### obtain iterator for predict batch "
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 61,
+   "execution_count": 34,
    "id": "afe585d6",
    "metadata": {},
    "outputs": [],
    "source": [
-    "def predict_abstract_iter():\n",
-    "    for batch in sample_batch:\n",
-    "        yield tokenizer(batch)"
+    "def predict_abstract_iter(batch):\n",
+    "    for sample in batch:\n",
+    "        yield tokenizer(sample)"
    ]
   },
   {
@@ -1818,13 +1826,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 62,
+   "execution_count": 35,
    "id": "8bfaa8bd",
    "metadata": {},
    "outputs": [],
    "source": [
-    "def toDataTensor():\n",
-    "    predict_generator = predict_abstract_iter()\n",
+    "def toDataTensor(batch):\n",
+    "    predict_generator = predict_abstract_iter(batch)\n",
     "    return [torch.tensor(vocab.lookup_indices(item)) for item in predict_generator]"
    ]
   },
   {
@@ -1838,7 +1846,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 63,
+   "execution_count": 36,
    "id": "6e2c35ba",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
        "device(type='cuda')"
       ]
      },
-     "execution_count": 63,
+     "execution_count": 36,
      "metadata": {},
      "output_type": "execute_result"
     }
    ]
@@ -1868,23 +1876,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 68,
+   "execution_count": 40,
    "id": "223eed8a",
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "text/plain": [
-       ""
-      ]
-     },
-     "execution_count": 68,
-     "metadata": {},
-     "output_type": "execute_result"
+     "ename": "RuntimeError",
+     "evalue": "Error(s) in loading state_dict for TransformerModel:\n\tsize mismatch for encoder.weight: copying a param with shape torch.Size([84399, 200]) from checkpoint, the shape in current model is torch.Size([6526, 200]).\n\tsize mismatch for decoder.weight: copying a param with shape torch.Size([84399, 200]) from checkpoint, the shape in current model is torch.Size([6526, 200]).\n\tsize mismatch for decoder.bias: copying a param with shape torch.Size([84399]) from checkpoint, the shape in current model is torch.Size([6526]).",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn [40], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m best_model \u001b[38;5;241m=\u001b[39m TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mbest_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mautocomplete_model\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1667\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m 1662\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[1;32m 1663\u001b[0m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1664\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(k) \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)))\n\u001b[1;32m 1666\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 1667\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1668\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 1669\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for TransformerModel:\n\tsize mismatch for encoder.weight: copying a param with shape torch.Size([84399, 200]) from checkpoint, the shape in current model is torch.Size([6526, 200]).\n\tsize mismatch for decoder.weight: copying a param with shape torch.Size([84399, 200]) from checkpoint, the shape in current model is torch.Size([6526, 200]).\n\tsize mismatch for decoder.bias: copying a param with shape torch.Size([84399]) from checkpoint, the shape in current model is torch.Size([6526])."
+     ]
     }
    ],
    "source": [
-    "best_model.load_state_dict(torch.load(\"autocomplete_model\"))"
+    "best_model = torch.load(\"pubmed-sentencecomplete.pt\")\n",
+    "best_model.eval()"
    ]
   },
   {
@@ -1943,7 +1954,8 @@
-    "    # 2*count is need because spaces count aswell\n",
+    "    # 2*count is needed because spaces count as well\n",
     "    mask_size = bptt+(iteration) \n",
     "    src_mask = generate_square_subsequent_mask(mask_size).to(device)\n",
-    "    data = toDataTensor()\n",
+    "    data = toDataTensor(input_batch)\n",
+    "    \n",
     "    for i, d in enumerate(data):\n",
     "        predictions = predict(d, src_mask, num_of_pred)\n",
     "        print(\"Current input:\", i)\n",
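
Note on the checkpoint change above: the traceback shows why this diff moves from saving a `state_dict` to pickling the whole model. The weights in `autocomplete_model` were written from a run with an 84399-token vocab, while the freshly constructed `TransformerModel` used a 6526-token vocab, so `load_state_dict` fails on the embedding and decoder shapes. `torch.save(best_model, ...)` sidesteps the mismatch, at the cost of tying the file to the exact class definition and module path at load time. A minimal sketch of the alternative that keeps `state_dict` checkpoints by bundling the vocab and hyperparameters with the weights; it assumes the notebook's `TransformerModel`, `vocab`, and hyperparameter names (`emsize`, `nhead`, `d_hid`, `nlayers`, `dropout`), and the checkpoint filename is hypothetical:

```python
import torch

def save_checkpoint(model, vocab, config, path="pubmed-sentencecomplete-ckpt.pt"):
    # config is a plain dict of the notebook's hyperparameters, e.g.
    # {"emsize": 200, "nhead": ..., "d_hid": ..., "nlayers": ..., "dropout": ...}
    # (200 is confirmed by the embedding shapes in the traceback; the rest are
    # whatever the training cells used). Storing the vocab pins ntokens.
    torch.save({"state_dict": model.state_dict(),
                "vocab": vocab,
                "config": config}, path)

def load_checkpoint(path="pubmed-sentencecomplete-ckpt.pt", device="cpu"):
    ckpt = torch.load(path, map_location=device)
    cfg, vocab = ckpt["config"], ckpt["vocab"]
    # Rebuild the network from the stored config so every tensor shape,
    # including the vocab-sized encoder/decoder, matches the checkpoint.
    model = TransformerModel(len(vocab), cfg["emsize"], cfg["nhead"],
                             cfg["d_hid"], cfg["nlayers"], cfg["dropout"]).to(device)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()  # disable dropout before prediction, as the new cell does
    return model, vocab
```

Either way, calling `eval()` after loading matters for the prediction cells: it switches dropout off so completions are deterministic for a given input.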