remove some cpu specs from seq2seq model and the translate example (#2337)
indention
This commit is contained in:
parent
4dd5fef870
commit
d15feedccc
tensorflow
@ -83,14 +83,12 @@ class Seq2SeqModel(object):
|
|||||||
softmax_loss_function = None
|
softmax_loss_function = None
|
||||||
# Sampled softmax only makes sense if we sample less than vocabulary size.
|
# Sampled softmax only makes sense if we sample less than vocabulary size.
|
||||||
if num_samples > 0 and num_samples < self.target_vocab_size:
|
if num_samples > 0 and num_samples < self.target_vocab_size:
|
||||||
with tf.device("/cpu:0"):
|
|
||||||
w = tf.get_variable("proj_w", [size, self.target_vocab_size])
|
w = tf.get_variable("proj_w", [size, self.target_vocab_size])
|
||||||
w_t = tf.transpose(w)
|
w_t = tf.transpose(w)
|
||||||
b = tf.get_variable("proj_b", [self.target_vocab_size])
|
b = tf.get_variable("proj_b", [self.target_vocab_size])
|
||||||
output_projection = (w, b)
|
output_projection = (w, b)
|
||||||
|
|
||||||
def sampled_loss(inputs, labels):
|
def sampled_loss(inputs, labels):
|
||||||
with tf.device("/cpu:0"):
|
|
||||||
labels = tf.reshape(labels, [-1, 1])
|
labels = tf.reshape(labels, [-1, 1])
|
||||||
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
|
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
|
||||||
self.target_vocab_size)
|
self.target_vocab_size)
|
||||||
|
@ -260,7 +260,6 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
|
|||||||
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
|
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
|
||||||
|
|
||||||
with variable_scope.variable_scope(scope or "embedding_rnn_decoder"):
|
with variable_scope.variable_scope(scope or "embedding_rnn_decoder"):
|
||||||
with ops.device("/cpu:0"):
|
|
||||||
embedding = variable_scope.get_variable("embedding",
|
embedding = variable_scope.get_variable("embedding",
|
||||||
[num_symbols, embedding_size])
|
[num_symbols, embedding_size])
|
||||||
loop_function = _extract_argmax_and_embed(
|
loop_function = _extract_argmax_and_embed(
|
||||||
@ -398,7 +397,6 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
|
|||||||
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
|
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
|
||||||
|
|
||||||
with variable_scope.variable_scope(scope or "embedding_tied_rnn_seq2seq"):
|
with variable_scope.variable_scope(scope or "embedding_tied_rnn_seq2seq"):
|
||||||
with ops.device("/cpu:0"):
|
|
||||||
embedding = variable_scope.get_variable("embedding",
|
embedding = variable_scope.get_variable("embedding",
|
||||||
[num_symbols, embedding_size])
|
[num_symbols, embedding_size])
|
||||||
|
|
||||||
@ -636,7 +634,6 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
|
|||||||
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
|
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
|
||||||
|
|
||||||
with variable_scope.variable_scope(scope or "embedding_attention_decoder"):
|
with variable_scope.variable_scope(scope or "embedding_attention_decoder"):
|
||||||
with ops.device("/cpu:0"):
|
|
||||||
embedding = variable_scope.get_variable("embedding",
|
embedding = variable_scope.get_variable("embedding",
|
||||||
[num_symbols, embedding_size])
|
[num_symbols, embedding_size])
|
||||||
loop_function = _extract_argmax_and_embed(
|
loop_function = _extract_argmax_and_embed(
|
||||||
|
Loading…
Reference in New Issue
Block a user