Merge pull request #21299 from SayHiRay:patch-1

PiperOrigin-RevId: 209632789
commit e787c15ae8
TensorFlower Gardener 2018-08-21 11:48:51 -07:00
3 changed files with 5 additions and 5 deletions


@@ -1056,7 +1056,7 @@
"\n",
" attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
"\n",
" predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
" predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
" result.append(index_word[predicted_id])\n",
"\n",
" if index_word[predicted_id] == '<end>':\n",


@@ -610,7 +610,7 @@
"\n",
" # using a multinomial distribution to predict the word returned by the model\n",
" predictions = predictions / temperature\n",
" predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
" predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
" \n",
" # We pass the predicted word as the next input to the model\n",
" # along with the previous hidden state\n",


@@ -466,10 +466,10 @@
" # passing the concatenated vector to the GRU\n",
" output, state = self.gru(x)\n",
" \n",
" # output shape == (batch_size * max_length, hidden_size)\n",
" # output shape == (batch_size * 1, hidden_size)\n",
" output = tf.reshape(output, (-1, output.shape[2]))\n",
" \n",
" # output shape == (batch_size * max_length, vocab)\n",
" # output shape == (batch_size * 1, vocab)\n",
" x = self.fc(output)\n",
" \n",
" return x, state, attention_weights\n",
@@ -677,7 +677,7 @@
" attention_weights = tf.reshape(attention_weights, (-1, ))\n",
" attention_plot[t] = attention_weights.numpy()\n",
"\n",
" predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
" predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
"\n",
" result += targ_lang.idx2word[predicted_id] + ' '\n",
"\n",