remove some cpu specs from seq2seq model and the translate example (#2337)

Also fixes the indentation of the lines that were wrapped by the removed with blocks.
Authored by Raingo on 2016-05-24 20:05:07 -04:00, committed by Vijay Vasudevan
parent 4dd5fef870
commit d15feedccc
2 changed files with 12 additions and 17 deletions
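For context, the with tf.device("/cpu:0") blocks removed in this commit pinned the output-projection and embedding variables (and the sampled-softmax computation) to the CPU; without them, TensorFlow's placer chooses the device, typically a GPU when one is available and the op has a GPU kernel. Below is a minimal sketch of the difference, assuming the 2016-era graph API used in this diff; the variable names and shapes are hypothetical.

# Illustrative only -- not part of the commit; names and shapes are hypothetical.
import tensorflow as tf

# Before this change: the variable is explicitly pinned to the CPU.
with tf.device("/cpu:0"):
  w_pinned = tf.get_variable("proj_w_pinned", [128, 40000])

# After this change: no device scope, so placement is left to the placer.
w_default = tf.get_variable("proj_w_default", [128, 40000])

# log_device_placement reports where each op and variable actually landed.
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
  sess.run(tf.initialize_all_variables())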

@@ -83,17 +83,15 @@ class Seq2SeqModel(object):
     softmax_loss_function = None
     # Sampled softmax only makes sense if we sample less than vocabulary size.
     if num_samples > 0 and num_samples < self.target_vocab_size:
-      with tf.device("/cpu:0"):
-        w = tf.get_variable("proj_w", [size, self.target_vocab_size])
-        w_t = tf.transpose(w)
-        b = tf.get_variable("proj_b", [self.target_vocab_size])
+      w = tf.get_variable("proj_w", [size, self.target_vocab_size])
+      w_t = tf.transpose(w)
+      b = tf.get_variable("proj_b", [self.target_vocab_size])
       output_projection = (w, b)

       def sampled_loss(inputs, labels):
-        with tf.device("/cpu:0"):
-          labels = tf.reshape(labels, [-1, 1])
-          return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
-                                             self.target_vocab_size)
+        labels = tf.reshape(labels, [-1, 1])
+        return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
+                                          self.target_vocab_size)
       softmax_loss_function = sampled_loss

     # Create the internal multi-layer cell for our RNN.
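Taken together, the post-change path in the translate example builds the projection variables and the sampled loss with no device constraint. Below is a standalone sketch with made-up sizes, assuming the 2016-era tf.nn.sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, num_classes) argument order shown above.

import tensorflow as tf

# Hypothetical sizes for illustration; the real model takes these as arguments.
size = 128
target_vocab_size = 40000
num_samples = 512

# No tf.device("/cpu:0") scope around the projection variables anymore.
w = tf.get_variable("proj_w", [size, target_vocab_size])
w_t = tf.transpose(w)
b = tf.get_variable("proj_b", [target_vocab_size])

def sampled_loss(inputs, labels):
  # inputs: [batch, size] decoder outputs; labels: [batch] target word ids.
  labels = tf.reshape(labels, [-1, 1])
  return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
                                    target_vocab_size)

# Example call on dummy data: returns one loss value per example, shape [32].
loss = sampled_loss(tf.zeros([32, size]), tf.zeros([32], dtype=tf.int64))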

@@ -260,9 +260,8 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
     proj_biases.get_shape().assert_is_compatible_with([num_symbols])

   with variable_scope.variable_scope(scope or "embedding_rnn_decoder"):
-    with ops.device("/cpu:0"):
-      embedding = variable_scope.get_variable("embedding",
-                                              [num_symbols, embedding_size])
+    embedding = variable_scope.get_variable("embedding",
+                                            [num_symbols, embedding_size])
     loop_function = _extract_argmax_and_embed(
         embedding, output_projection,
         update_embedding_for_previous) if feed_previous else None
@@ -398,9 +397,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
     proj_biases.get_shape().assert_is_compatible_with([num_symbols])

   with variable_scope.variable_scope(scope or "embedding_tied_rnn_seq2seq"):
-    with ops.device("/cpu:0"):
-      embedding = variable_scope.get_variable("embedding",
-                                              [num_symbols, embedding_size])
+    embedding = variable_scope.get_variable("embedding",
+                                            [num_symbols, embedding_size])

     emb_encoder_inputs = [embedding_ops.embedding_lookup(embedding, x)
                           for x in encoder_inputs]
@@ -636,9 +634,8 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
     proj_biases.get_shape().assert_is_compatible_with([num_symbols])

   with variable_scope.variable_scope(scope or "embedding_attention_decoder"):
-    with ops.device("/cpu:0"):
-      embedding = variable_scope.get_variable("embedding",
-                                              [num_symbols, embedding_size])
+    embedding = variable_scope.get_variable("embedding",
+                                            [num_symbols, embedding_size])
     loop_function = _extract_argmax_and_embed(
         embedding, output_projection,
         update_embedding_for_previous) if feed_previous else None
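The three hunks above in the seq2seq library all drop the same pattern: the embedding variable is now created directly inside the variable scope, without an ops.device("/cpu:0") wrapper. Below is a minimal sketch of the resulting pattern, using the same internal modules the library itself imports; the scope name, vocabulary size, and ids are hypothetical.

import tensorflow as tf
from tensorflow.python.ops import embedding_ops, variable_scope

num_symbols = 40000     # hypothetical vocabulary size
embedding_size = 128    # hypothetical embedding width

with variable_scope.variable_scope("embedding_decoder_demo"):
  # Created with no device constraint; the placer decides where it lives.
  embedding = variable_scope.get_variable("embedding",
                                          [num_symbols, embedding_size])
  ids = tf.constant([3, 1, 4])                          # a few token ids
  emb = embedding_ops.embedding_lookup(embedding, ids)  # shape [3, embedding_size]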