diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index 94406f69254..c10d77f7efb 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -14,6 +14,7 @@ py_library(
     srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/contrib/distributions:distributions_py",
         "//tensorflow/contrib/layers:layers_py",
         "//tensorflow/contrib/rnn:rnn_py",
         "//tensorflow/python:array_ops",
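Note: the new `distributions_py` dependency is pulled in for `categorical.Categorical`, which the sampler added below uses to draw token ids from the decoder's output logits. A minimal sketch of that primitive, with made-up logits values (not part of this change):

```python
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import categorical

# Hypothetical [batch_size=2, vocabulary_size=3] logits from one RNN step.
logits = tf.constant([[1.0, 2.0, 0.5],
                      [0.1, 0.1, 3.0]])
# sample() draws one int id per batch row, distributed according to
# softmax(logits); a seed may be supplied for reproducibility.
sample_ids = categorical.Categorical(logits=logits).sample(seed=7)

with tf.Session() as sess:
  print(sess.run(sample_ids))  # e.g. [1 2]
```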
sess_results["step_finished"]) + sample_ids = sess_results["step_outputs"].sample_id + batch_where_not_sampling = np.where(sample_ids == -1) + batch_where_sampling = np.where(sample_ids > -1) + self.assertAllEqual( + sess_results["step_next_inputs"][batch_where_sampling], + embeddings[sample_ids[batch_where_sampling]]) + self.assertAllEqual( + sess_results["step_next_inputs"][batch_where_not_sampling], + np.squeeze(inputs[batch_where_not_sampling, 1])) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py index 0c86c00411f..8f15cdc27aa 100644 --- a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py @@ -24,6 +24,7 @@ import collections import six +from tensorflow.contrib.distributions.python.ops import categorical from tensorflow.contrib.rnn import core_rnn_cell from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import dtypes @@ -31,14 +32,17 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest __all__ = [ "Sampler", "SamplingDecoderOutput", "BasicSamplingDecoder", "BasicTrainingSampler", "GreedyEmbeddingSampler", "CustomSampler", + "ScheduledEmbeddingTrainingSampler", ] _transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access @@ -205,69 +209,167 @@ class BasicTrainingSampler(Sampler): Returned sample_ids are the argmax of the RNN output logits. """ - def __init__(self, inputs, sequence_length, time_major=False): + def __init__(self, inputs, sequence_length, time_major=False, name=None): """Initializer. Args: inputs: A (structure of) input tensors. sequence_length: An int32 vector tensor. - time_major: Python bool. + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + name: Name scope for any created operations. Raises: ValueError: if `sequence_length` is not a 1D tensor. 
""" - inputs = ops.convert_to_tensor(inputs, name="inputs") - if not time_major: - inputs = nest.map_structure(_transpose_batch_time, inputs) + with ops.name_scope( + name, "BasicTrainingSampler", [inputs, sequence_length]): + inputs = ops.convert_to_tensor(inputs, name="inputs") + if not time_major: + inputs = nest.map_structure(_transpose_batch_time, inputs) - def _unstack_ta(inp): - return tensor_array_ops.TensorArray( - dtype=inp.dtype, size=array_ops.shape(inp)[0], - element_shape=inp.get_shape()[1:]).unstack(inp) + def _unstack_ta(inp): + return tensor_array_ops.TensorArray( + dtype=inp.dtype, size=array_ops.shape(inp)[0], + element_shape=inp.get_shape()[1:]).unstack(inp) - self._input_tas = nest.map_structure(_unstack_ta, inputs) - self._sequence_length = ops.convert_to_tensor( - sequence_length, name="sequence_length") - if self._sequence_length.get_shape().ndims != 1: - raise ValueError( - "Expected sequence_length to be a vector, but received shape: %s" % - self._sequence_length.get_shape()) + self._input_tas = nest.map_structure(_unstack_ta, inputs) + self._sequence_length = ops.convert_to_tensor( + sequence_length, name="sequence_length") + if self._sequence_length.get_shape().ndims != 1: + raise ValueError( + "Expected sequence_length to be a vector, but received shape: %s" % + self._sequence_length.get_shape()) - self._zero_inputs = nest.map_structure( - lambda inp: array_ops.zeros_like(inp[0, :]), inputs) + self._zero_inputs = nest.map_structure( + lambda inp: array_ops.zeros_like(inp[0, :]), inputs) - self._batch_size = array_ops.size(sequence_length) + self._batch_size = array_ops.size(sequence_length) @property def batch_size(self): return self._batch_size def initialize(self, name=None): - finished = math_ops.equal(0, self._sequence_length) - all_finished = math_ops.reduce_all(finished) - next_inputs = control_flow_ops.cond( - all_finished, lambda: self._zero_inputs, - lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas)) - return (finished, next_inputs) + with ops.name_scope(name, "BasicTrainingSamplerInitialize"): + finished = math_ops.equal(0, self._sequence_length) + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, lambda: self._zero_inputs, + lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas)) + return (finished, next_inputs) def sample(self, time, outputs, name=None, **unused_kwargs): - del time # unused by sample_fn - sample_ids = math_ops.cast( - math_ops.argmax(outputs, axis=-1), dtypes.int32) - return sample_ids + with ops.name_scope(name, "BasicTrainingSamplerSample", [time, outputs]): + sample_ids = math_ops.cast( + math_ops.argmax(outputs, axis=-1), dtypes.int32) + return sample_ids def next_inputs(self, time, outputs, state, name=None, **unused_kwargs): """next_inputs_fn for BasicTrainingSampler.""" - del outputs # unused by next_inputs_fn - next_time = time + 1 - finished = (next_time >= self._sequence_length) - all_finished = math_ops.reduce_all(finished) - def read_from_ta(inp): - return inp.read(next_time) - next_inputs = control_flow_ops.cond( - all_finished, lambda: self._zero_inputs, - lambda: nest.map_structure(read_from_ta, self._input_tas)) - return (finished, next_inputs, state) + with ops.name_scope( + name, "BasicTrainingSamplerNextInputs", [time, outputs, state]): + next_time = time + 1 + finished = (next_time >= self._sequence_length) + all_finished = math_ops.reduce_all(finished) + def read_from_ta(inp): + return inp.read(next_time) + next_inputs = 
 
   @property
   def batch_size(self):
     return self._batch_size
 
   def initialize(self, name=None):
-    finished = math_ops.equal(0, self._sequence_length)
-    all_finished = math_ops.reduce_all(finished)
-    next_inputs = control_flow_ops.cond(
-        all_finished, lambda: self._zero_inputs,
-        lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
-    return (finished, next_inputs)
+    with ops.name_scope(name, "BasicTrainingSamplerInitialize"):
+      finished = math_ops.equal(0, self._sequence_length)
+      all_finished = math_ops.reduce_all(finished)
+      next_inputs = control_flow_ops.cond(
+          all_finished, lambda: self._zero_inputs,
+          lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
+      return (finished, next_inputs)
 
   def sample(self, time, outputs, name=None, **unused_kwargs):
-    del time  # unused by sample_fn
-    sample_ids = math_ops.cast(
-        math_ops.argmax(outputs, axis=-1), dtypes.int32)
-    return sample_ids
+    with ops.name_scope(name, "BasicTrainingSamplerSample", [time, outputs]):
+      sample_ids = math_ops.cast(
+          math_ops.argmax(outputs, axis=-1), dtypes.int32)
+      return sample_ids
 
   def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
     """next_inputs_fn for BasicTrainingSampler."""
-    del outputs  # unused by next_inputs_fn
-    next_time = time + 1
-    finished = (next_time >= self._sequence_length)
-    all_finished = math_ops.reduce_all(finished)
-    def read_from_ta(inp):
-      return inp.read(next_time)
-    next_inputs = control_flow_ops.cond(
-        all_finished, lambda: self._zero_inputs,
-        lambda: nest.map_structure(read_from_ta, self._input_tas))
-    return (finished, next_inputs, state)
+    with ops.name_scope(
+        name, "BasicTrainingSamplerNextInputs", [time, outputs, state]):
+      next_time = time + 1
+      finished = (next_time >= self._sequence_length)
+      all_finished = math_ops.reduce_all(finished)
+      def read_from_ta(inp):
+        return inp.read(next_time)
+      next_inputs = control_flow_ops.cond(
+          all_finished, lambda: self._zero_inputs,
+          lambda: nest.map_structure(read_from_ta, self._input_tas))
+      return (finished, next_inputs, state)
+
+
+class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler):
+  """A training sampler that adds scheduled sampling.
+
+  Returns -1s for sample_ids where no sampling took place; valid sample id
+  values elsewhere.
+  """
+
+  def __init__(self, inputs, sequence_length, embedding, sampling_probability,
+               time_major=False, seed=None, scheduling_seed=None, name=None):
+    """Initializer.
+
+    Args:
+      inputs: A (structure of) input tensors.
+      sequence_length: An int32 vector tensor.
+      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
+        or the `params` argument for `embedding_lookup`.
+      sampling_probability: A `float32` scalar or vector tensor: the
+        probability of sampling categorically from the output ids instead of
+        reading directly from the inputs.
+      time_major: Python bool.  Whether the tensors in `inputs` are time major.
+        If `False` (default), they are assumed to be batch major.
+      seed: The sampling seed.
+      scheduling_seed: The schedule decision rule sampling seed.
+      name: Name scope for any created operations.
+
+    Raises:
+      ValueError: if `sampling_probability` is not a scalar or vector.
+    """
+    with ops.name_scope(name, "ScheduledEmbeddingSamplingWrapper",
+                        [embedding, sampling_probability]):
+      if callable(embedding):
+        self._embedding_fn = embedding
+      else:
+        self._embedding_fn = (
+            lambda ids: embedding_ops.embedding_lookup(embedding, ids))
+      self._sampling_probability = ops.convert_to_tensor(
+          sampling_probability, name="sampling_probability")
+      if self._sampling_probability.get_shape().ndims not in (0, 1):
+        raise ValueError(
+            "sampling_probability must be either a scalar or a vector. "
" + "saw shape: %s" % (self._sampling_probability.get_shape())) + self._seed = seed + self._scheduling_seed = scheduling_seed + super(ScheduledEmbeddingTrainingSampler, self).__init__( + inputs=inputs, + sequence_length=sequence_length, + time_major=time_major, + name=name) + + def initialize(self, name=None): + return super(ScheduledEmbeddingTrainingSampler, self).initialize( + name=name) + + def sample(self, time, outputs, state, name=None): + with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample", + [time, outputs, state]): + # Return -1s where we did not sample, and sample_ids elsewhere + select_sample_noise = random_ops.random_uniform( + [self.batch_size], seed=self._scheduling_seed) + select_sample = (self._sampling_probability > select_sample_noise) + sample_id_sampler = categorical.Categorical(logits=outputs) + return array_ops.where( + select_sample, + sample_id_sampler.sample(seed=self._seed), + array_ops.tile([-1], [self.batch_size])) + + def next_inputs(self, time, outputs, state, sample_ids, name=None): + with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample", + [time, outputs, state, sample_ids]): + (finished, base_next_inputs, state) = ( + super(ScheduledEmbeddingTrainingSampler, self).next_inputs( + time=time, outputs=outputs, state=state, sample_ids=sample_ids, + name=name)) + + def maybe_sample(): + where_sampling = math_ops.cast(sample_ids > -1, dtypes.int32) + _, sample_ids_sampling = data_flow_ops.dynamic_partition( + sample_ids, where_sampling, 2) + inputs_not_sampling, _ = data_flow_ops.dynamic_partition( + base_next_inputs, where_sampling, 2) + partitioned_indices = data_flow_ops.dynamic_partition( + math_ops.range(array_ops.size(where_sampling)), where_sampling, 2) + sampled_next_inputs = self._embedding_fn(sample_ids_sampling) + return data_flow_ops.dynamic_stitch( + partitioned_indices, (inputs_not_sampling, sampled_next_inputs)) + + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, lambda: base_next_inputs, maybe_sample) + return (finished, next_inputs, state) class GreedyEmbeddingSampler(Sampler):