remove some cpu specs from seq2seq model and the translate example (#2337)
indention
This commit is contained in:
parent
4dd5fef870
commit
d15feedccc
@ -83,17 +83,15 @@ class Seq2SeqModel(object):
|
||||
softmax_loss_function = None
|
||||
# Sampled softmax only makes sense if we sample less than vocabulary size.
|
||||
if num_samples > 0 and num_samples < self.target_vocab_size:
|
||||
with tf.device("/cpu:0"):
|
||||
w = tf.get_variable("proj_w", [size, self.target_vocab_size])
|
||||
w_t = tf.transpose(w)
|
||||
b = tf.get_variable("proj_b", [self.target_vocab_size])
|
||||
w = tf.get_variable("proj_w", [size, self.target_vocab_size])
|
||||
w_t = tf.transpose(w)
|
||||
b = tf.get_variable("proj_b", [self.target_vocab_size])
|
||||
output_projection = (w, b)
|
||||
|
||||
def sampled_loss(inputs, labels):
|
||||
with tf.device("/cpu:0"):
|
||||
labels = tf.reshape(labels, [-1, 1])
|
||||
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
|
||||
self.target_vocab_size)
|
||||
labels = tf.reshape(labels, [-1, 1])
|
||||
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
|
||||
self.target_vocab_size)
|
||||
softmax_loss_function = sampled_loss
|
||||
|
||||
# Create the internal multi-layer cell for our RNN.
|
||||
|
@ -260,9 +260,8 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
|
||||
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
|
||||
|
||||
with variable_scope.variable_scope(scope or "embedding_rnn_decoder"):
|
||||
with ops.device("/cpu:0"):
|
||||
embedding = variable_scope.get_variable("embedding",
|
||||
[num_symbols, embedding_size])
|
||||
embedding = variable_scope.get_variable("embedding",
|
||||
[num_symbols, embedding_size])
|
||||
loop_function = _extract_argmax_and_embed(
|
||||
embedding, output_projection,
|
||||
update_embedding_for_previous) if feed_previous else None
|
||||
@ -398,9 +397,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
|
||||
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
|
||||
|
||||
with variable_scope.variable_scope(scope or "embedding_tied_rnn_seq2seq"):
|
||||
with ops.device("/cpu:0"):
|
||||
embedding = variable_scope.get_variable("embedding",
|
||||
[num_symbols, embedding_size])
|
||||
embedding = variable_scope.get_variable("embedding",
|
||||
[num_symbols, embedding_size])
|
||||
|
||||
emb_encoder_inputs = [embedding_ops.embedding_lookup(embedding, x)
|
||||
for x in encoder_inputs]
|
||||
@ -636,9 +634,8 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
|
||||
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
|
||||
|
||||
with variable_scope.variable_scope(scope or "embedding_attention_decoder"):
|
||||
with ops.device("/cpu:0"):
|
||||
embedding = variable_scope.get_variable("embedding",
|
||||
[num_symbols, embedding_size])
|
||||
embedding = variable_scope.get_variable("embedding",
|
||||
[num_symbols, embedding_size])
|
||||
loop_function = _extract_argmax_and_embed(
|
||||
embedding, output_projection,
|
||||
update_embedding_for_previous) if feed_previous else None
|
||||
|
Loading…
Reference in New Issue
Block a user