|
@ -435,12 +435,12 @@ |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"id": "da8fb12b", |
|
|
|
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "3b78cc08", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
|
|
|
"source": [] |
|
|
|
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define pos encoder" |
|
|
|
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
@ -471,6 +471,14 @@ |
|
|
" return self.dropout(x)\n" |
|
|
" return self.dropout(x)\n" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "0adefcce", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define function to create batches of data and create batches" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 48, |
|
|
"execution_count": 48, |
|
@ -509,6 +517,14 @@ |
|
|
"test_data = batchify(test_data, eval_batch_size)" |
|
|
"test_data = batchify(test_data, eval_batch_size)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "4f407ad0", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define function to get batch" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 50, |
|
|
"execution_count": 50, |
|
@ -533,6 +549,14 @@ |
|
|
" return data, target" |
|
|
" return data, target" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "7ee28c38", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define parameters and init model" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 51, |
|
|
"execution_count": 51, |
|
@ -549,10 +573,18 @@ |
|
|
"model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)" |
|
|
"model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "51f2400a", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### init optimizer, loss, scheduler etc." |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 52, |
|
|
|
|
|
"id": "50ab3fb6", |
|
|
|
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"id": "b9a04e07", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
@ -562,7 +594,24 @@ |
|
|
"criterion = nn.CrossEntropyLoss()\n", |
|
|
"criterion = nn.CrossEntropyLoss()\n", |
|
|
"lr = 5.0 # learning rate\n", |
|
|
"lr = 5.0 # learning rate\n", |
|
|
"optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", |
|
|
"optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", |
|
|
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)\n", |
|
|
|
|
|
|
|
|
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "07317af8", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define train function" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 52, |
|
|
|
|
|
"id": "50ab3fb6", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
"\n", |
|
|
"\n", |
|
|
"def train(model: nn.Module) -> None:\n", |
|
|
"def train(model: nn.Module) -> None:\n", |
|
|
" model.train() # turn on train mode\n", |
|
|
" model.train() # turn on train mode\n", |
|
@ -595,8 +644,24 @@ |
|
|
" f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n", |
|
|
" f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n", |
|
|
" f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n", |
|
|
" f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n", |
|
|
" total_loss = 0\n", |
|
|
" total_loss = 0\n", |
|
|
" start_time = time.time()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
|
|
|
" start_time = time.time()" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "23709949", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define evaluate function" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 289, |
|
|
|
|
|
"id": "689bd4ea", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
"def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n", |
|
|
"def evaluate(model: nn.Module, eval_data: Tensor) -> float:\n", |
|
|
" model.eval() # turn on evaluation mode\n", |
|
|
" model.eval() # turn on evaluation mode\n", |
|
|
" total_loss = 0.\n", |
|
|
" total_loss = 0.\n", |
|
@ -613,11 +678,21 @@ |
|
|
" return total_loss / (len(eval_data) - 1)" |
|
|
" return total_loss / (len(eval_data) - 1)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "d7c6a1e0", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### now we can start training the model while saving best one" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 53, |
|
|
"execution_count": 53, |
|
|
"id": "09c4d4ce", |
|
|
"id": "09c4d4ce", |
|
|
"metadata": {}, |
|
|
|
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"scrolled": true |
|
|
|
|
|
}, |
|
|
"outputs": [ |
|
|
"outputs": [ |
|
|
{ |
|
|
{ |
|
|
"name": "stdout", |
|
|
"name": "stdout", |
|
@ -689,11 +764,21 @@ |
|
|
" scheduler.step()" |
|
|
" scheduler.step()" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "565b5aa4", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### print info about best model after training" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 54, |
|
|
"execution_count": 54, |
|
|
"id": "12fdd0aa", |
|
|
"id": "12fdd0aa", |
|
|
"metadata": {}, |
|
|
|
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"scrolled": true |
|
|
|
|
|
}, |
|
|
"outputs": [ |
|
|
"outputs": [ |
|
|
{ |
|
|
{ |
|
|
"name": "stdout", |
|
|
"name": "stdout", |
|
@ -714,6 +799,14 @@ |
|
|
"print('=' * 89)" |
|
|
"print('=' * 89)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "12031065", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Now we can try to predict based on trained model" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"cell_type": "markdown", |
|
|
"id": "e685d3e1", |
|
|
"id": "e685d3e1", |
|
@ -724,7 +817,7 @@ |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 274, |
|
|
|
|
|
|
|
|
"execution_count": 300, |
|
|
"id": "cfb30fe0", |
|
|
"id": "cfb30fe0", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
@ -735,9 +828,17 @@ |
|
|
"]" |
|
|
"]" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "054ada71", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define source mask for model" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 275, |
|
|
|
|
|
|
|
|
"execution_count": 301, |
|
|
"id": "305853e8", |
|
|
"id": "305853e8", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
@ -746,9 +847,17 @@ |
|
|
"src_mask = generate_square_subsequent_mask(bptt).to(device)" |
|
|
"src_mask = generate_square_subsequent_mask(bptt).to(device)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "4635a73e", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### define iterator for predict batch and init to generator" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 276, |
|
|
|
|
|
|
|
|
"execution_count": 302, |
|
|
"id": "afe585d6", |
|
|
"id": "afe585d6", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
@ -760,16 +869,16 @@ |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"id": "f7ac6188", |
|
|
|
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "1c171c8c", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
|
|
|
"source": [] |
|
|
|
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### load data into tensor for model to process" |
|
|
|
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 278, |
|
|
|
|
|
|
|
|
"execution_count": 303, |
|
|
"id": "0788b045", |
|
|
"id": "0788b045", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
@ -779,7 +888,7 @@ |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 279, |
|
|
|
|
|
|
|
|
"execution_count": 308, |
|
|
"id": "8bfaa8bd", |
|
|
"id": "8bfaa8bd", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [ |
|
|
"outputs": [ |
|
@ -789,7 +898,7 @@ |
|
|
"[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]" |
|
|
"[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
"execution_count": 279, |
|
|
|
|
|
|
|
|
"execution_count": 308, |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"output_type": "execute_result" |
|
|
"output_type": "execute_result" |
|
|
} |
|
|
} |
|
@ -798,10 +907,18 @@ |
|
|
"data" |
|
|
"data" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "99132b3d", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### check device once again (prob not needed)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 280, |
|
|
|
|
|
"id": "dd0e7310", |
|
|
|
|
|
|
|
|
"execution_count": 309, |
|
|
|
|
|
"id": "b8c50c8c", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [ |
|
|
"outputs": [ |
|
|
{ |
|
|
{ |
|
@ -810,7 +927,7 @@ |
|
|
"device(type='cuda')" |
|
|
"device(type='cuda')" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
"execution_count": 280, |
|
|
|
|
|
|
|
|
"execution_count": 309, |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"output_type": "execute_result" |
|
|
"output_type": "execute_result" |
|
|
} |
|
|
} |
|
@ -821,87 +938,21 @@ |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 281, |
|
|
|
|
|
"id": "1728f0fd", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"for d in data:\n", |
|
|
|
|
|
" d.to(device)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 282, |
|
|
|
|
|
"id": "49d27864", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"best_model.eval()\n", |
|
|
|
|
|
"for batch in data:\n", |
|
|
|
|
|
" output = best_model(batch.to(device), src_mask)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 283, |
|
|
|
|
|
"id": "a3404169", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"result_np = []\n", |
|
|
|
|
|
"pred_np = output.cpu().detach().numpy()\n", |
|
|
|
|
|
"for el in pred_np:\n", |
|
|
|
|
|
" result_np.append(el)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 284, |
|
|
|
|
|
"id": "c7064c0c", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"[array([[-0.5258564 , 4.6108465 , 4.8358154 , ..., -0.46871045,\n", |
|
|
|
|
|
" -0.04386039, -0.13068362],\n", |
|
|
|
|
|
" [-0.35341978, 8.821181 , 9.57295 , ..., -1.2820313 ,\n", |
|
|
|
|
|
" -0.989242 , -0.15542248],\n", |
|
|
|
|
|
" [-0.39717233, 5.7111125 , 6.4295497 , ..., -0.27339834,\n", |
|
|
|
|
|
" -1.5333815 , 0.16042188]], dtype=float32),\n", |
|
|
|
|
|
" array([[-0.5291481 , 4.6452312 , 4.7958803 , ..., -0.4642661 ,\n", |
|
|
|
|
|
" -0.04427804, -0.12225106],\n", |
|
|
|
|
|
" [-0.35347146, 8.824585 , 9.5098095 , ..., -1.2693769 ,\n", |
|
|
|
|
|
" -0.97772634, -0.13521233],\n", |
|
|
|
|
|
" [-0.39733842, 5.693817 , 6.368334 , ..., -0.26423275,\n", |
|
|
|
|
|
" -1.527182 , 0.16518843]], dtype=float32),\n", |
|
|
|
|
|
" array([[-0.53349733, 4.6777644 , 4.7953978 , ..., -0.4346264 ,\n", |
|
|
|
|
|
" -0.03433151, -0.11583059],\n", |
|
|
|
|
|
" [-0.3695631 , 8.82613 , 9.477964 , ..., -1.2404828 ,\n", |
|
|
|
|
|
" -0.9594678 , -0.11550146],\n", |
|
|
|
|
|
" [-0.41203997, 5.699034 , 6.341655 , ..., -0.24370295,\n", |
|
|
|
|
|
" -1.5179048 , 0.16230991]], dtype=float32)]" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": 284, |
|
|
|
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "05766f6b", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"output_type": "execute_result" |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
"source": [ |
|
|
"result_np" |
|
|
|
|
|
|
|
|
"### define predict function" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 285, |
|
|
|
|
|
"id": "679e2316", |
|
|
|
|
|
|
|
|
"execution_count": 317, |
|
|
|
|
|
"id": "0475bcc9", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"def predict(input_line, n_predictions=1):\n", |
|
|
|
|
|
|
|
|
"def predict(input_line, n_predictions=3):\n", |
|
|
" print('\\n> %s' % input_line)\n", |
|
|
" print('\\n> %s' % input_line)\n", |
|
|
" with torch.no_grad():\n", |
|
|
" with torch.no_grad():\n", |
|
|
" output = best_model(input_line.to(device), src_mask)\n", |
|
|
" output = best_model(input_line.to(device), src_mask)\n", |
|
@ -920,31 +971,64 @@ |
|
|
" print(vocab.lookup_token(predict_token_index))\n", |
|
|
" print(vocab.lookup_token(predict_token_index))\n", |
|
|
" #print(category_index)\n", |
|
|
" #print(category_index)\n", |
|
|
" #print('(%.2f) %s' % (value, all_categories[category_index]))\n", |
|
|
" #print('(%.2f) %s' % (value, all_categories[category_index]))\n", |
|
|
" #predictions.append([value, all_categories[category_index]])" |
|
|
|
|
|
|
|
|
" predictions.append(vocab.lookup_token(predict_token_index))\n", |
|
|
|
|
|
" return predictions" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"id": "8ad2f64b", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### Execute prediction and display predicted values" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 286, |
|
|
|
|
|
"id": "03389137", |
|
|
|
|
|
|
|
|
"execution_count": 318, |
|
|
|
|
|
"id": "55b73ea1", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [ |
|
|
"outputs": [ |
|
|
{ |
|
|
{ |
|
|
"name": "stdout", |
|
|
"name": "stdout", |
|
|
"output_type": "stream", |
|
|
"output_type": "stream", |
|
|
"text": [ |
|
|
"text": [ |
|
|
|
|
|
"[tensor([ 3, 555, 16]), tensor([ 3, 76, 16])]\n", |
|
|
"\n", |
|
|
"\n", |
|
|
"> tensor([ 3, 555, 16])\n", |
|
|
"> tensor([ 3, 555, 16])\n", |
|
|
"tumors\n", |
|
|
"tumors\n", |
|
|
|
|
|
"the\n", |
|
|
|
|
|
"the\n", |
|
|
|
|
|
"The brain is\n", |
|
|
|
|
|
"Possible continuations:\n", |
|
|
|
|
|
"0 : tumors\n", |
|
|
|
|
|
"0 : the\n", |
|
|
|
|
|
"0 : the\n", |
|
|
"\n", |
|
|
"\n", |
|
|
"> tensor([ 3, 76, 16])\n", |
|
|
"> tensor([ 3, 76, 16])\n", |
|
|
"cancer\n" |
|
|
|
|
|
|
|
|
"cancer\n", |
|
|
|
|
|
"most\n", |
|
|
|
|
|
"the\n", |
|
|
|
|
|
"The lung is\n", |
|
|
|
|
|
"Possible continuations:\n", |
|
|
|
|
|
"0 : cancer\n", |
|
|
|
|
|
"0 : most\n", |
|
|
|
|
|
"0 : the\n" |
|
|
] |
|
|
] |
|
|
} |
|
|
} |
|
|
], |
|
|
], |
|
|
"source": [ |
|
|
"source": [ |
|
|
|
|
|
"print(data)\n", |
|
|
|
|
|
"count = 0\n", |
|
|
|
|
|
"num_of_pred = 3\n", |
|
|
"for d in data:\n", |
|
|
"for d in data:\n", |
|
|
" predict(d)" |
|
|
|
|
|
|
|
|
" predictions = predict(d, num_of_pred)\n", |
|
|
|
|
|
" print(input_batch[count])\n", |
|
|
|
|
|
" print(\"Possible continuations:\")\n", |
|
|
|
|
|
" for j in range(len(predictions)):\n", |
|
|
|
|
|
" print(i, \": \", predictions[j])\n", |
|
|
|
|
|
" count = count + 1\n", |
|
|
|
|
|
" " |
|
|
] |
|
|
] |
|
|
} |
|
|
} |
|
|
], |
|
|
], |
|
|