Browse Source

fix prediction and add doc

dev_neuralnet
Leonard Starke 2 years ago
parent
commit
4e62562df3
  1. 290
      AutomaticSentenceCompletion.ipynb

290
AutomaticSentenceCompletion.ipynb

@ -435,12 +435,12 @@
] ]
}, },
{ {
"cell_type": "code",
"execution_count": null,
"id": "da8fb12b",
"cell_type": "markdown",
"id": "3b78cc08",
"metadata": {}, "metadata": {},
"outputs": [],
"source": []
"source": [
"### Define the positional encoder"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
@ -471,6 +471,14 @@
" return self.dropout(x)\n" " return self.dropout(x)\n"
] ]
}, },
{
"cell_type": "markdown",
"id": "0adefcce",
"metadata": {},
"source": [
"### Define a function that splits the data into batches, then batchify the datasets"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 48, "execution_count": 48,
@ -509,6 +517,14 @@
"test_data = batchify(test_data, eval_batch_size)" "test_data = batchify(test_data, eval_batch_size)"
] ]
}, },
{
"cell_type": "markdown",
"id": "4f407ad0",
"metadata": {},
"source": [
"### define function to get batch"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 50, "execution_count": 50,
@ -533,6 +549,14 @@
" return data, target" " return data, target"
] ]
}, },
{
"cell_type": "markdown",
"id": "7ee28c38",
"metadata": {},
"source": [
"### define parameters and init model"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 51, "execution_count": 51,
@ -549,10 +573,18 @@
"model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)" "model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)"
] ]
}, },
{
"cell_type": "markdown",
"id": "51f2400a",
"metadata": {},
"source": [
"### init optimizer, loss, scheduler etc."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 52,
"id": "50ab3fb6",
"execution_count": null,
"id": "b9a04e07",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -562,7 +594,24 @@
"criterion = nn.CrossEntropyLoss()\n", "criterion = nn.CrossEntropyLoss()\n",
"lr = 5.0 # learning rate\n", "lr = 5.0 # learning rate\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", "optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)"
]
},
{
"cell_type": "markdown",
"id": "07317af8",
"metadata": {},
"source": [
"### define train function"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "50ab3fb6",
"metadata": {},
"outputs": [],
"source": [
"\n", "\n",
"def train(model: nn.Module) -> None:\n", "def train(model: nn.Module) -> None:\n",
" model.train() # turn on train mode\n", " model.train() # turn on train mode\n",
@ -595,8 +644,24 @@
" f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n", " f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n",
" f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n", " f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n",
" total_loss = 0\n", " total_loss = 0\n",
" start_time = time.time()\n",
"\n",
" start_time = time.time()"
]
},
{
"cell_type": "markdown",
"id": "23709949",
"metadata": {},
"source": [
"### define evaluate function"
]
},
{
"cell_type": "code",
"execution_count": 289,
"id": "689bd4ea",
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n", "def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n",
" model.eval() # turn on evaluation mode\n", " model.eval() # turn on evaluation mode\n",
" total_loss = 0.\n", " total_loss = 0.\n",
@ -613,11 +678,21 @@
" return total_loss / (len(eval_data) - 1)" " return total_loss / (len(eval_data) - 1)"
] ]
}, },
{
"cell_type": "markdown",
"id": "d7c6a1e0",
"metadata": {},
"source": [
"### Train the model, saving the best one along the way"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 53, "execution_count": 53,
"id": "09c4d4ce", "id": "09c4d4ce",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -689,11 +764,21 @@
" scheduler.step()" " scheduler.step()"
] ]
}, },
{
"cell_type": "markdown",
"id": "565b5aa4",
"metadata": {},
"source": [
"### print info about best model after training"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 54, "execution_count": 54,
"id": "12fdd0aa", "id": "12fdd0aa",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -714,6 +799,14 @@
"print('=' * 89)" "print('=' * 89)"
] ]
}, },
{
"cell_type": "markdown",
"id": "12031065",
"metadata": {},
"source": [
"## Now we can try to predict based on trained model"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "e685d3e1", "id": "e685d3e1",
@ -724,7 +817,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 274,
"execution_count": 300,
"id": "cfb30fe0", "id": "cfb30fe0",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -735,9 +828,17 @@
"]" "]"
] ]
}, },
{
"cell_type": "markdown",
"id": "054ada71",
"metadata": {},
"source": [
"### define source mask for model"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 275,
"execution_count": 301,
"id": "305853e8", "id": "305853e8",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -746,9 +847,17 @@
"src_mask = generate_square_subsequent_mask(bptt).to(device)" "src_mask = generate_square_subsequent_mask(bptt).to(device)"
] ]
}, },
{
"cell_type": "markdown",
"id": "4635a73e",
"metadata": {},
"source": [
"### define iterator for predict batch and init to generator"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 276,
"execution_count": 302,
"id": "afe585d6", "id": "afe585d6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -760,16 +869,16 @@
] ]
}, },
{ {
"cell_type": "code",
"execution_count": null,
"id": "f7ac6188",
"cell_type": "markdown",
"id": "1c171c8c",
"metadata": {}, "metadata": {},
"outputs": [],
"source": []
"source": [
"### load data into tensor for model to process"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 278,
"execution_count": 303,
"id": "0788b045", "id": "0788b045",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -779,7 +888,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 279,
"execution_count": 308,
"id": "8bfaa8bd", "id": "8bfaa8bd",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -789,7 +898,7 @@
"[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]" "[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]"
] ]
}, },
"execution_count": 279,
"execution_count": 308,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -798,10 +907,18 @@
"data" "data"
] ]
}, },
{
"cell_type": "markdown",
"id": "99132b3d",
"metadata": {},
"source": [
"### Check the device once again (probably not needed)"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 280,
"id": "dd0e7310",
"execution_count": 309,
"id": "b8c50c8c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -810,7 +927,7 @@
"device(type='cuda')" "device(type='cuda')"
] ]
}, },
"execution_count": 280,
"execution_count": 309,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -821,87 +938,21 @@
] ]
}, },
{ {
"cell_type": "code",
"execution_count": 281,
"id": "1728f0fd",
"metadata": {},
"outputs": [],
"source": [
"for d in data:\n",
" d.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 282,
"id": "49d27864",
"metadata": {},
"outputs": [],
"source": [
"best_model.eval()\n",
"for batch in data:\n",
" output = best_model(batch.to(device), src_mask)"
]
},
{
"cell_type": "code",
"execution_count": 283,
"id": "a3404169",
"metadata": {},
"outputs": [],
"source": [
"result_np = []\n",
"pred_np = output.cpu().detach().numpy()\n",
"for el in pred_np:\n",
" result_np.append(el)"
]
},
{
"cell_type": "code",
"execution_count": 284,
"id": "c7064c0c",
"cell_type": "markdown",
"id": "05766f6b",
"metadata": {}, "metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([[-0.5258564 , 4.6108465 , 4.8358154 , ..., -0.46871045,\n",
" -0.04386039, -0.13068362],\n",
" [-0.35341978, 8.821181 , 9.57295 , ..., -1.2820313 ,\n",
" -0.989242 , -0.15542248],\n",
" [-0.39717233, 5.7111125 , 6.4295497 , ..., -0.27339834,\n",
" -1.5333815 , 0.16042188]], dtype=float32),\n",
" array([[-0.5291481 , 4.6452312 , 4.7958803 , ..., -0.4642661 ,\n",
" -0.04427804, -0.12225106],\n",
" [-0.35347146, 8.824585 , 9.5098095 , ..., -1.2693769 ,\n",
" -0.97772634, -0.13521233],\n",
" [-0.39733842, 5.693817 , 6.368334 , ..., -0.26423275,\n",
" -1.527182 , 0.16518843]], dtype=float32),\n",
" array([[-0.53349733, 4.6777644 , 4.7953978 , ..., -0.4346264 ,\n",
" -0.03433151, -0.11583059],\n",
" [-0.3695631 , 8.82613 , 9.477964 , ..., -1.2404828 ,\n",
" -0.9594678 , -0.11550146],\n",
" [-0.41203997, 5.699034 , 6.341655 , ..., -0.24370295,\n",
" -1.5179048 , 0.16230991]], dtype=float32)]"
]
},
"execution_count": 284,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"result_np"
"### define predict function"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 285,
"id": "679e2316",
"execution_count": 317,
"id": "0475bcc9",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def predict(input_line, n_predictions=1):\n",
"def predict(input_line, n_predictions=3):\n",
" print('\\n> %s' % input_line)\n", " print('\\n> %s' % input_line)\n",
" with torch.no_grad():\n", " with torch.no_grad():\n",
" output = best_model(input_line.to(device), src_mask)\n", " output = best_model(input_line.to(device), src_mask)\n",
@ -920,31 +971,64 @@
" print(vocab.lookup_token(predict_token_index))\n", " print(vocab.lookup_token(predict_token_index))\n",
" #print(category_index)\n", " #print(category_index)\n",
" #print('(%.2f) %s' % (value, all_categories[category_index]))\n", " #print('(%.2f) %s' % (value, all_categories[category_index]))\n",
" #predictions.append([value, all_categories[category_index]])"
" predictions.append(vocab.lookup_token(predict_token_index))\n",
" return predictions"
]
},
{
"cell_type": "markdown",
"id": "8ad2f64b",
"metadata": {},
"source": [
"### Execute prediction and display predicted values"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 286,
"id": "03389137",
"execution_count": 318,
"id": "55b73ea1",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]\n",
"\n", "\n",
"> tensor([ 3, 555, 16])\n", "> tensor([ 3, 555, 16])\n",
"tumors\n", "tumors\n",
"the\n",
"the\n",
"The brain is\n",
"Possible continuations:\n",
"0 : tumors\n",
"0 : the\n",
"0 : the\n",
"\n", "\n",
"> tensor([ 3, 76, 16])\n", "> tensor([ 3, 76, 16])\n",
"cancer\n"
"cancer\n",
"most\n",
"the\n",
"The lung is\n",
"Possible continuations:\n",
"0 : cancer\n",
"0 : most\n",
"0 : the\n"
] ]
} }
], ],
"source": [ "source": [
"print(data)\n",
"count = 0\n",
"num_of_pred = 3\n",
"for d in data:\n", "for d in data:\n",
" predict(d)"
"    predictions = predict(d, num_of_pred)\n",
"    print(input_batch[count])\n",
"    print(\"Possible continuations:\")\n",
"    # Number each continuation by its own rank. The original printed `i`, a name this\n",
"    # cell never assigns, so it only ran thanks to a stale kernel variable (always 0).\n",
"    for j, continuation in enumerate(predictions):\n",
"        print(j, \": \", continuation)\n",
"    count = count + 1\n",
" "
] ]
} }
], ],

Loading…
Cancel
Save