Browse Source

store complete model\nnot just the parameters but the entire instance of the model is saved to disk for lower code complexity when loading

dev_neuralnet
Constantin Fürst 2 years ago
parent
commit
783ec89423
  1. 110
      AutomaticSentenceCompletion.ipynb

110
AutomaticSentenceCompletion.ipynb

@ -44,7 +44,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"id": "3209935b",
"id": "293027a6",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -101,7 +101,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6,
"execution_count": 9,
"id": "00481ec9", "id": "00481ec9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -128,7 +128,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7,
"execution_count": 10,
"id": "dcf5c217", "id": "dcf5c217",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -149,10 +149,19 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30,
"execution_count": 11,
"id": "c3199444", "id": "c3199444",
"metadata": {}, "metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hein/.local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [ "source": [
"try:\n", "try:\n",
" import torch\n", " import torch\n",
@ -176,7 +185,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31,
"execution_count": 12,
"id": "daca9db6", "id": "daca9db6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -198,7 +207,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33,
"execution_count": 13,
"id": "8d2312db", "id": "8d2312db",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -216,7 +225,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 34,
"execution_count": 14,
"id": "3f23404d", "id": "3f23404d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -228,7 +237,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35,
"execution_count": 15,
"id": "8a128d3c", "id": "8a128d3c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -240,7 +249,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36,
"execution_count": 16,
"id": "97e89986", "id": "97e89986",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -252,7 +261,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37,
"execution_count": 17,
"id": "0d6e89c4", "id": "0d6e89c4",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -272,7 +281,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38,
"execution_count": 18,
"id": "0bdbc40a", "id": "0bdbc40a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -293,7 +302,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39,
"execution_count": 19,
"id": "a438ab1f", "id": "a438ab1f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -314,7 +323,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40,
"execution_count": 20,
"id": "0e5bc361", "id": "0e5bc361",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -327,7 +336,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 41,
"execution_count": 21,
"id": "dfd7400d", "id": "dfd7400d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -350,7 +359,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42,
"execution_count": 22,
"id": "c155ee31", "id": "c155ee31",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -389,7 +398,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 46,
"execution_count": 23,
"id": "a33d722f", "id": "a33d722f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -453,7 +462,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 47,
"execution_count": 24,
"id": "c2f6d33b", "id": "c2f6d33b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -490,7 +499,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 48,
"execution_count": 25,
"id": "9e184841", "id": "9e184841",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -514,7 +523,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 49,
"execution_count": 26,
"id": "a4def1ac", "id": "a4def1ac",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -536,7 +545,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 50,
"execution_count": 27,
"id": "4ab5b8fd", "id": "4ab5b8fd",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -568,7 +577,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 51,
"execution_count": 28,
"id": "c53764da", "id": "c53764da",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -592,7 +601,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 52,
"execution_count": 29,
"id": "ddaa1d64", "id": "ddaa1d64",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -616,7 +625,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 54,
"execution_count": 30,
"id": "50ab3fb6", "id": "50ab3fb6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -665,7 +674,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 55,
"execution_count": 31,
"id": "3d179bb0", "id": "3d179bb0",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -1736,7 +1745,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"torch.save(best_model.state_dict(), \"autocomplete_model\")"
"torch.save(best_model, \"pubmed-sentencecomplete.pt\")"
] ]
}, },
{ {
@ -1757,14 +1766,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 59,
"execution_count": 32,
"id": "cfb30fe0", "id": "cfb30fe0",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_batch = [\n", "sample_batch = [\n",
" \"The brain is\",\n", " \"The brain is\",\n",
" \"The lung is\"\n",
"]\n", "]\n",
"input_batch = sample_batch" "input_batch = sample_batch"
] ]
@ -1779,7 +1787,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 60,
"execution_count": 33,
"id": "305853e8", "id": "305853e8",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -1793,18 +1801,18 @@
"id": "fe250072", "id": "fe250072",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### define iterator for predict batch "
"### obtain iterator for predict batch "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 61,
"execution_count": 34,
"id": "afe585d6", "id": "afe585d6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def predict_abstract_iter():\n",
" for batch in sample_batch:\n",
"def predict_abstract_iter(batch):\n",
" for batch in batch:\n",
" yield tokenizer(batch)" " yield tokenizer(batch)"
] ]
}, },
@ -1818,13 +1826,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 62,
"execution_count": 35,
"id": "8bfaa8bd", "id": "8bfaa8bd",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def toDataTensor():\n",
" predict_generator = predict_abstract_iter()\n",
"def toDataTensor(batch):\n",
" predict_generator = predict_abstract_iter(batch)\n",
" return [torch.tensor(vocab.lookup_indices(item)) for item in predict_generator]" " return [torch.tensor(vocab.lookup_indices(item)) for item in predict_generator]"
] ]
}, },
@ -1838,7 +1846,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 63,
"execution_count": 36,
"id": "6e2c35ba", "id": "6e2c35ba",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1848,7 +1856,7 @@
"device(type='cuda')" "device(type='cuda')"
] ]
}, },
"execution_count": 63,
"execution_count": 36,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -1868,23 +1876,26 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 68,
"execution_count": 40,
"id": "223eed8a", "id": "223eed8a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
"ename": "RuntimeError",
"evalue": "Error(s) in loading state_dict for TransformerModel:\n\tsize mismatch for encoder.weight: copying a param with shape torch.Size([84399, 200]) from checkpoint, the shape in current model is torch.Size([6526, 200]).\n\tsize mismatch for decoder.weight: copying a param with shape torch.Size([84399, 200]) from checkpoint, the shape in current model is torch.Size([6526, 200]).\n\tsize mismatch for decoder.bias: copying a param with shape torch.Size([84399]) from checkpoint, the shape in current model is torch.Size([6526]).",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [40], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m best_model \u001b[38;5;241m=\u001b[39m TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mbest_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mautocomplete_model\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1667\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m 1662\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[1;32m 1663\u001b[0m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1664\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(k) \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)))\n\u001b[1;32m 1666\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 1667\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1668\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 1669\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n",
"\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for TransformerModel:\n\tsize mismatch for encoder.weight: copying a param with shape torch.Size([84399, 200]) from checkpoint, the shape in current model is torch.Size([6526, 200]).\n\tsize mismatch for decoder.weight: copying a param with shape torch.Size([84399, 200]) from checkpoint, the shape in current model is torch.Size([6526, 200]).\n\tsize mismatch for decoder.bias: copying a param with shape torch.Size([84399]) from checkpoint, the shape in current model is torch.Size([6526])."
]
} }
], ],
"source": [ "source": [
"best_model.load_state_dict(torch.load(\"autocomplete_model\"))"
"best_model = torch.load(\"pubmed-sentencecomplete.pt\")\n",
"best_model.eval()"
] ]
}, },
{ {
@ -1943,7 +1954,8 @@
" # 2*count is need because spaces count aswell\n", " # 2*count is need because spaces count aswell\n",
" mask_size = bptt+(iteration) \n", " mask_size = bptt+(iteration) \n",
" src_mask = generate_square_subsequent_mask(mask_size).to(device)\n", " src_mask = generate_square_subsequent_mask(mask_size).to(device)\n",
" data = toDataTensor()\n",
" data = toDataTensor(input_batch)\n",
" \n",
" for i, d in enumerate(data):\n", " for i, d in enumerate(data):\n",
" predictions = predict(d, src_mask, num_of_pred)\n", " predictions = predict(d, src_mask, num_of_pred)\n",
" print(\"Current input:\", i)\n", " print(\"Current input:\", i)\n",

Loading…
Cancel
Save