Merge pull request #21299 from SayHiRay:patch-1
PiperOrigin-RevId: 209632789
This commit is contained in:
commit
e787c15ae8
@ -1056,7 +1056,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
|
" attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
|
" predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
|
||||||
" result.append(index_word[predicted_id])\n",
|
" result.append(index_word[predicted_id])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if index_word[predicted_id] == '<end>':\n",
|
" if index_word[predicted_id] == '<end>':\n",
|
||||||
|
@ -610,7 +610,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # using a multinomial distribution to predict the word returned by the model\n",
|
" # using a multinomial distribution to predict the word returned by the model\n",
|
||||||
" predictions = predictions / temperature\n",
|
" predictions = predictions / temperature\n",
|
||||||
" predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
|
" predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # We pass the predicted word as the next input to the model\n",
|
" # We pass the predicted word as the next input to the model\n",
|
||||||
" # along with the previous hidden state\n",
|
" # along with the previous hidden state\n",
|
||||||
|
@ -466,10 +466,10 @@
|
|||||||
" # passing the concatenated vector to the GRU\n",
|
" # passing the concatenated vector to the GRU\n",
|
||||||
" output, state = self.gru(x)\n",
|
" output, state = self.gru(x)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # output shape == (batch_size * max_length, hidden_size)\n",
|
" # output shape == (batch_size * 1, hidden_size)\n",
|
||||||
" output = tf.reshape(output, (-1, output.shape[2]))\n",
|
" output = tf.reshape(output, (-1, output.shape[2]))\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # output shape == (batch_size * max_length, vocab)\n",
|
" # output shape == (batch_size * 1, vocab)\n",
|
||||||
" x = self.fc(output)\n",
|
" x = self.fc(output)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" return x, state, attention_weights\n",
|
" return x, state, attention_weights\n",
|
||||||
@ -677,7 +677,7 @@
|
|||||||
" attention_weights = tf.reshape(attention_weights, (-1, ))\n",
|
" attention_weights = tf.reshape(attention_weights, (-1, ))\n",
|
||||||
" attention_plot[t] = attention_weights.numpy()\n",
|
" attention_plot[t] = attention_weights.numpy()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
|
" predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" result += targ_lang.idx2word[predicted_id] + ' '\n",
|
" result += targ_lang.idx2word[predicted_id] + ' '\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user