diff --git a/AutomaticSentenceCompletion.ipynb b/AutomaticSentenceCompletion.ipynb index 6724ae3..dd28b1a 100644 --- a/AutomaticSentenceCompletion.ipynb +++ b/AutomaticSentenceCompletion.ipynb @@ -435,12 +435,12 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "da8fb12b", + "cell_type": "markdown", + "id": "3b78cc08", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "### define pos encoder" + ] }, { "cell_type": "code", @@ -471,6 +471,14 @@ " return self.dropout(x)\n" ] }, + { + "cell_type": "markdown", + "id": "0adefcce", + "metadata": {}, + "source": [ + "### define function to create batches of data and create batches" + ] + }, { "cell_type": "code", "execution_count": 48, @@ -509,6 +517,14 @@ "test_data = batchify(test_data, eval_batch_size)" ] }, + { + "cell_type": "markdown", + "id": "4f407ad0", + "metadata": {}, + "source": [ + "### define function to get batch" + ] + }, { "cell_type": "code", "execution_count": 50, @@ -533,6 +549,14 @@ " return data, target" ] }, + { + "cell_type": "markdown", + "id": "7ee28c38", + "metadata": {}, + "source": [ + "### define parameters and init model" + ] + }, { "cell_type": "code", "execution_count": 51, @@ -549,10 +573,18 @@ "model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)" ] }, + { + "cell_type": "markdown", + "id": "51f2400a", + "metadata": {}, + "source": [ + "### init optimizer, loss, scheduler etc." 
+ ] + }, { "cell_type": "code", - "execution_count": 52, - "id": "50ab3fb6", + "execution_count": null, + "id": "b9a04e07", "metadata": {}, "outputs": [], "source": [ @@ -562,7 +594,24 @@ "criterion = nn.CrossEntropyLoss()\n", "lr = 5.0 # learning rate\n", "optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", - "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)\n", + "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)" + ] + }, + { + "cell_type": "markdown", + "id": "07317af8", + "metadata": {}, + "source": [ + "### define train function" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "50ab3fb6", + "metadata": {}, + "outputs": [], + "source": [ "\n", "def train(model: nn.Module) -> None:\n", " model.train() # turn on train mode\n", @@ -595,8 +644,24 @@ " f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n", " f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n", " total_loss = 0\n", - " start_time = time.time()\n", - "\n", + " start_time = time.time()" + ] + }, + { + "cell_type": "markdown", + "id": "23709949", + "metadata": {}, + "source": [ + "### define evaluate function" + ] + }, + { + "cell_type": "code", + "execution_count": 289, + "id": "689bd4ea", + "metadata": {}, + "outputs": [], + "source": [ "def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n", " model.eval() # turn on evaluation mode\n", " total_loss = 0.\n", @@ -613,11 +678,21 @@ " return total_loss / (len(eval_data) - 1)" ] }, + { + "cell_type": "markdown", + "id": "d7c6a1e0", + "metadata": {}, + "source": [ + "### now we can start training the model while saving best one" + ] + }, { "cell_type": "code", "execution_count": 53, "id": "09c4d4ce", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", @@ -689,11 +764,21 @@ " scheduler.step()" ] }, + { + "cell_type": "markdown", + "id": "565b5aa4", + "metadata": {}, + "source": [ + "### print info about best model after training" + ] + }, { 
"cell_type": "code", "execution_count": 54, "id": "12fdd0aa", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", @@ -714,6 +799,14 @@ "print('=' * 89)" ] }, + { + "cell_type": "markdown", + "id": "12031065", + "metadata": {}, + "source": [ + "## Now we can try to predict based on trained model" + ] + }, { "cell_type": "markdown", "id": "e685d3e1", @@ -724,7 +817,7 @@ }, { "cell_type": "code", - "execution_count": 274, + "execution_count": 300, "id": "cfb30fe0", "metadata": {}, "outputs": [], @@ -735,9 +828,17 @@ "]" ] }, + { + "cell_type": "markdown", + "id": "054ada71", + "metadata": {}, + "source": [ + "### define source mask for model" + ] + }, { "cell_type": "code", - "execution_count": 275, + "execution_count": 301, "id": "305853e8", "metadata": {}, "outputs": [], @@ -746,9 +847,17 @@ "src_mask = generate_square_subsequent_mask(bptt).to(device)" ] }, + { + "cell_type": "markdown", + "id": "4635a73e", + "metadata": {}, + "source": [ + "### define iterator for predict batch and init to generator" + ] + }, { "cell_type": "code", - "execution_count": 276, + "execution_count": 302, "id": "afe585d6", "metadata": {}, "outputs": [], @@ -760,16 +869,16 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "f7ac6188", + "cell_type": "markdown", + "id": "1c171c8c", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "### load data into tensor for model to process" + ] }, { "cell_type": "code", - "execution_count": 278, + "execution_count": 303, "id": "0788b045", "metadata": {}, "outputs": [], @@ -779,7 +888,7 @@ }, { "cell_type": "code", - "execution_count": 279, + "execution_count": 308, "id": "8bfaa8bd", "metadata": {}, "outputs": [ @@ -789,7 +898,7 @@ "[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]" ] }, - "execution_count": 279, + "execution_count": 308, "metadata": {}, "output_type": "execute_result" } @@ -798,10 +907,18 @@ "data" ] }, + { + "cell_type": "markdown", + "id": "99132b3d", + "metadata": 
{}, + "source": [ + "### check device once again (prob not needed)" + ] + }, { "cell_type": "code", - "execution_count": 280, - "id": "dd0e7310", + "execution_count": 309, + "id": "b8c50c8c", "metadata": {}, "outputs": [ { @@ -810,7 +927,7 @@ "device(type='cuda')" ] }, - "execution_count": 280, + "execution_count": 309, "metadata": {}, "output_type": "execute_result" } @@ -821,87 +938,21 @@ ] }, { - "cell_type": "code", - "execution_count": 281, - "id": "1728f0fd", - "metadata": {}, - "outputs": [], - "source": [ - "for d in data:\n", - " d.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 282, - "id": "49d27864", - "metadata": {}, - "outputs": [], - "source": [ - "best_model.eval()\n", - "for batch in data:\n", - " output = best_model(batch.to(device), src_mask)" - ] - }, - { - "cell_type": "code", - "execution_count": 283, - "id": "a3404169", - "metadata": {}, - "outputs": [], - "source": [ - "result_np = []\n", - "pred_np = output.cpu().detach().numpy()\n", - "for el in pred_np:\n", - " result_np.append(el)" - ] - }, - { - "cell_type": "code", - "execution_count": 284, - "id": "c7064c0c", + "cell_type": "markdown", + "id": "05766f6b", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[array([[-0.5258564 , 4.6108465 , 4.8358154 , ..., -0.46871045,\n", - " -0.04386039, -0.13068362],\n", - " [-0.35341978, 8.821181 , 9.57295 , ..., -1.2820313 ,\n", - " -0.989242 , -0.15542248],\n", - " [-0.39717233, 5.7111125 , 6.4295497 , ..., -0.27339834,\n", - " -1.5333815 , 0.16042188]], dtype=float32),\n", - " array([[-0.5291481 , 4.6452312 , 4.7958803 , ..., -0.4642661 ,\n", - " -0.04427804, -0.12225106],\n", - " [-0.35347146, 8.824585 , 9.5098095 , ..., -1.2693769 ,\n", - " -0.97772634, -0.13521233],\n", - " [-0.39733842, 5.693817 , 6.368334 , ..., -0.26423275,\n", - " -1.527182 , 0.16518843]], dtype=float32),\n", - " array([[-0.53349733, 4.6777644 , 4.7953978 , ..., -0.4346264 ,\n", - " -0.03433151, -0.11583059],\n", - " [-0.3695631 , 
8.82613 , 9.477964 , ..., -1.2404828 ,\n", - " -0.9594678 , -0.11550146],\n", - " [-0.41203997, 5.699034 , 6.341655 , ..., -0.24370295,\n", - " -1.5179048 , 0.16230991]], dtype=float32)]" - ] - }, - "execution_count": 284, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "result_np" + "### define predict function" ] }, { "cell_type": "code", - "execution_count": 285, - "id": "679e2316", + "execution_count": 317, + "id": "0475bcc9", "metadata": {}, "outputs": [], "source": [ - "def predict(input_line, n_predictions=1):\n", + "def predict(input_line, n_predictions=3):\n", " print('\\n> %s' % input_line)\n", " with torch.no_grad():\n", " output = best_model(input_line.to(device), src_mask)\n", @@ -920,31 +971,64 @@ " print(vocab.lookup_token(predict_token_index))\n", " #print(category_index)\n", " #print('(%.2f) %s' % (value, all_categories[category_index]))\n", - " #predictions.append([value, all_categories[category_index]])" + " predictions.append(vocab.lookup_token(predict_token_index))\n", + " return predictions" + ] + }, + { + "cell_type": "markdown", + "id": "8ad2f64b", + "metadata": {}, + "source": [ + "### Execute prediction and display predicted values" ] }, { "cell_type": "code", - "execution_count": 286, - "id": "03389137", + "execution_count": 318, + "id": "55b73ea1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]\n", "\n", "> tensor([ 3, 555, 16])\n", "tumors\n", + "the\n", + "the\n", + "The brain is\n", + "Possible continuations:\n", + "0 : tumors\n", + "0 : the\n", + "0 : the\n", "\n", "> tensor([ 3, 76, 16])\n", - "cancer\n" + "cancer\n", + "most\n", + "the\n", + "The lung is\n", + "Possible continuations:\n", + "0 : cancer\n", + "0 : most\n", + "0 : the\n" ] } ], "source": [ + "print(data)\n", + "count = 0\n", + "num_of_pred = 3\n", "for d in data:\n", - " predict(d)" + " predictions = predict(d, num_of_pred)\n", + " 
print(input_batch[count])\n", + " print(\"Possible continuations:\")\n", + " for j in range(len(predictions)):\n", + " print(j, \": \", predictions[j])  # label each continuation with its own loop index\n", + " count = count + 1\n", + " " ] } ],