Add new scheduled sampler to new seq2seq api.

Change: 146519142
Author: Eugene Brevdo
Date: 2017-02-03 14:43:42 -08:00
Committed by: TensorFlower Gardener
parent 56f266acf2
commit 6abf7cfee9
3 changed files with 223 additions and 39 deletions


@@ -14,6 +14,7 @@ py_library(
     srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/contrib/distributions:distributions_py",
         "//tensorflow/contrib/layers:layers_py",
         "//tensorflow/contrib/rnn:rnn_py",
         "//tensorflow/python:array_ops",


@@ -176,6 +176,87 @@ class BasicSamplingDecoderTest(test.TestCase):
       self.assertAllEqual(expected_step_next_inputs,
                           sess_results["step_next_inputs"])

+  def testStepWithScheduledEmbeddingTrainingSampler(self):
+    sequence_length = [3, 4, 3, 1, 0]
+    batch_size = 5
+    max_time = 8
+    input_depth = 7
+    vocabulary_size = 10
+
+    with self.test_session() as sess:
+      inputs = np.random.randn(
+          batch_size, max_time, input_depth).astype(np.float32)
+      embeddings = np.random.randn(
+          vocabulary_size, input_depth).astype(np.float32)
+      half = constant_op.constant(0.5)
+      cell = core_rnn_cell.LSTMCell(vocabulary_size)
+      sampler = sampling_decoder.ScheduledEmbeddingTrainingSampler(
+          inputs=inputs, sequence_length=sequence_length,
+          embedding=embeddings, sampling_probability=half,
+          time_major=False)
+      my_decoder = sampling_decoder.BasicSamplingDecoder(
+          cell=cell,
+          sampler=sampler,
+          initial_state=cell.zero_state(
+              dtype=dtypes.float32, batch_size=batch_size))
+      output_size = my_decoder.output_size
+      output_dtype = my_decoder.output_dtype
+      self.assertEqual(
+          sampling_decoder.SamplingDecoderOutput(vocabulary_size,
+                                                 tensor_shape.TensorShape([])),
+          output_size)
+      self.assertEqual(
+          sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32),
+          output_dtype)
+
+      (first_finished, first_inputs, first_state) = my_decoder.initialize()
+      (step_outputs, step_state, step_next_inputs,
+       step_finished) = my_decoder.step(
+           constant_op.constant(0), first_inputs, first_state)
+      batch_size_t = my_decoder.batch_size
+
+      self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
+      self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+      self.assertTrue(
+          isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput))
+      self.assertEqual((batch_size, vocabulary_size),
+                       step_outputs[0].get_shape())
+      self.assertEqual((batch_size,), step_outputs[1].get_shape())
+      self.assertEqual((batch_size, vocabulary_size),
+                       first_state[0].get_shape())
+      self.assertEqual((batch_size, vocabulary_size),
+                       first_state[1].get_shape())
+      self.assertEqual((batch_size, vocabulary_size),
+                       step_state[0].get_shape())
+      self.assertEqual((batch_size, vocabulary_size),
+                       step_state[1].get_shape())
+
+      sess.run(variables.global_variables_initializer())
+      sess_results = sess.run({
+          "batch_size": batch_size_t,
+          "first_finished": first_finished,
+          "first_inputs": first_inputs,
+          "first_state": first_state,
+          "step_outputs": step_outputs,
+          "step_state": step_state,
+          "step_next_inputs": step_next_inputs,
+          "step_finished": step_finished
+      })
+
+      self.assertAllEqual([False, False, False, False, True],
+                          sess_results["first_finished"])
+      self.assertAllEqual([False, False, False, True, True],
+                          sess_results["step_finished"])
+      sample_ids = sess_results["step_outputs"].sample_id
+      batch_where_not_sampling = np.where(sample_ids == -1)
+      batch_where_sampling = np.where(sample_ids > -1)
+      self.assertAllEqual(
+          sess_results["step_next_inputs"][batch_where_sampling],
+          embeddings[sample_ids[batch_where_sampling]])
+      self.assertAllEqual(
+          sess_results["step_next_inputs"][batch_where_not_sampling],
+          np.squeeze(inputs[batch_where_not_sampling, 1]))
+

 if __name__ == "__main__":
   test.main()
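
The last two assertions above pin down the mixing rule: where scheduled sampling fired, the next input is the embedded sampled id; elsewhere it is the teacher-forced input for the next time step. A minimal NumPy restatement with hypothetical values (the step runs at time 0, so the fallback is inputs[:, 1]):

import numpy as np

# Hypothetical shapes matching the test: batch_size=5, max_time=8, input_depth=7.
inputs = np.random.randn(5, 8, 7).astype(np.float32)
embeddings = np.random.randn(10, 7).astype(np.float32)
sample_ids = np.array([3, -1, 7, -1, -1])  # -1 => sampling did not fire

sampled = sample_ids > -1
expected_next_inputs = np.where(
    sampled[:, None],
    embeddings[np.where(sampled, sample_ids, 0)],  # embedded sampled ids
    inputs[:, 1])                                  # teacher-forced next inputs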


@@ -24,6 +24,7 @@ import collections

 import six

+from tensorflow.contrib.distributions.python.ops import categorical
 from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.contrib.seq2seq.python.ops import decoder
 from tensorflow.python.framework import dtypes
@@ -31,14 +32,17 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.util import nest

 __all__ = [
     "Sampler", "SamplingDecoderOutput", "BasicSamplingDecoder",
     "BasicTrainingSampler", "GreedyEmbeddingSampler", "CustomSampler",
+    "ScheduledEmbeddingTrainingSampler",
 ]

 _transpose_batch_time = decoder._transpose_batch_time  # pylint: disable=protected-access
@@ -205,69 +209,167 @@ class BasicTrainingSampler(Sampler):
   Returned sample_ids are the argmax of the RNN output logits.
   """

-  def __init__(self, inputs, sequence_length, time_major=False):
+  def __init__(self, inputs, sequence_length, time_major=False, name=None):
     """Initializer.

     Args:
       inputs: A (structure of) input tensors.
       sequence_length: An int32 vector tensor.
-      time_major: Python bool.
+      time_major: Python bool. Whether the tensors in `inputs` are time major.
+        If `False` (default), they are assumed to be batch major.
+      name: Name scope for any created operations.

     Raises:
       ValueError: if `sequence_length` is not a 1D tensor.
     """
-    inputs = ops.convert_to_tensor(inputs, name="inputs")
-    if not time_major:
-      inputs = nest.map_structure(_transpose_batch_time, inputs)
+    with ops.name_scope(
+        name, "BasicTrainingSampler", [inputs, sequence_length]):
+      inputs = ops.convert_to_tensor(inputs, name="inputs")
+      if not time_major:
+        inputs = nest.map_structure(_transpose_batch_time, inputs)

-    def _unstack_ta(inp):
-      return tensor_array_ops.TensorArray(
-          dtype=inp.dtype, size=array_ops.shape(inp)[0],
-          element_shape=inp.get_shape()[1:]).unstack(inp)
+      def _unstack_ta(inp):
+        return tensor_array_ops.TensorArray(
+            dtype=inp.dtype, size=array_ops.shape(inp)[0],
+            element_shape=inp.get_shape()[1:]).unstack(inp)

-    self._input_tas = nest.map_structure(_unstack_ta, inputs)
-    self._sequence_length = ops.convert_to_tensor(
-        sequence_length, name="sequence_length")
-    if self._sequence_length.get_shape().ndims != 1:
-      raise ValueError(
-          "Expected sequence_length to be a vector, but received shape: %s" %
-          self._sequence_length.get_shape())
+      self._input_tas = nest.map_structure(_unstack_ta, inputs)
+      self._sequence_length = ops.convert_to_tensor(
+          sequence_length, name="sequence_length")
+      if self._sequence_length.get_shape().ndims != 1:
+        raise ValueError(
+            "Expected sequence_length to be a vector, but received shape: %s" %
+            self._sequence_length.get_shape())

-    self._zero_inputs = nest.map_structure(
-        lambda inp: array_ops.zeros_like(inp[0, :]), inputs)
+      self._zero_inputs = nest.map_structure(
+          lambda inp: array_ops.zeros_like(inp[0, :]), inputs)

-    self._batch_size = array_ops.size(sequence_length)
+      self._batch_size = array_ops.size(sequence_length)

   @property
   def batch_size(self):
     return self._batch_size

   def initialize(self, name=None):
-    finished = math_ops.equal(0, self._sequence_length)
-    all_finished = math_ops.reduce_all(finished)
-    next_inputs = control_flow_ops.cond(
-        all_finished, lambda: self._zero_inputs,
-        lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
-    return (finished, next_inputs)
+    with ops.name_scope(name, "BasicTrainingSamplerInitialize"):
+      finished = math_ops.equal(0, self._sequence_length)
+      all_finished = math_ops.reduce_all(finished)
+      next_inputs = control_flow_ops.cond(
+          all_finished, lambda: self._zero_inputs,
+          lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
+      return (finished, next_inputs)

   def sample(self, time, outputs, name=None, **unused_kwargs):
-    del time  # unused by sample_fn
-    sample_ids = math_ops.cast(
-        math_ops.argmax(outputs, axis=-1), dtypes.int32)
-    return sample_ids
+    with ops.name_scope(name, "BasicTrainingSamplerSample", [time, outputs]):
+      sample_ids = math_ops.cast(
+          math_ops.argmax(outputs, axis=-1), dtypes.int32)
+      return sample_ids

   def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
     """next_inputs_fn for BasicTrainingSampler."""
-    del outputs  # unused by next_inputs_fn
-    next_time = time + 1
-    finished = (next_time >= self._sequence_length)
-    all_finished = math_ops.reduce_all(finished)
-    def read_from_ta(inp):
-      return inp.read(next_time)
-    next_inputs = control_flow_ops.cond(
-        all_finished, lambda: self._zero_inputs,
-        lambda: nest.map_structure(read_from_ta, self._input_tas))
-    return (finished, next_inputs, state)
+    with ops.name_scope(
+        name, "BasicTrainingSamplerNextInputs", [time, outputs, state]):
+      next_time = time + 1
+      finished = (next_time >= self._sequence_length)
+      all_finished = math_ops.reduce_all(finished)
+      def read_from_ta(inp):
+        return inp.read(next_time)
+      next_inputs = control_flow_ops.cond(
+          all_finished, lambda: self._zero_inputs,
+          lambda: nest.map_structure(read_from_ta, self._input_tas))
+      return (finished, next_inputs, state)
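
The only change to BasicTrainingSampler in the hunk above is wrapping op construction in ops.name_scope, so each sampler method's ops group under one graph scope. A minimal sketch of the effect, assuming TF 1.x graph mode (names illustrative):

import tensorflow as tf

# tf.name_scope(name, default_name): the default is used when `name` is None,
# mirroring the `name=None` arguments added above.
with tf.name_scope(None, "BasicTrainingSampler"):
  t = tf.convert_to_tensor([1.0, 2.0], name="inputs")
print(t.op.name)  # "BasicTrainingSampler/inputs"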
+
+
+class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler):
+  """A training sampler that adds scheduled sampling.
+
+  Returns -1s for sample_ids where no sampling took place; valid sample id
+  values elsewhere.
+  """
+
+  def __init__(self, inputs, sequence_length, embedding, sampling_probability,
+               time_major=False, seed=None, scheduling_seed=None, name=None):
+    """Initializer.
+
+    Args:
+      inputs: A (structure of) input tensors.
+      sequence_length: An int32 vector tensor.
+      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
+        or the `params` argument for `embedding_lookup`.
+      sampling_probability: A 0D `float32` tensor: the probability of sampling
+        categorically from the output ids instead of reading directly from the
+        inputs.
+      time_major: Python bool. Whether the tensors in `inputs` are time major.
+        If `False` (default), they are assumed to be batch major.
+      seed: The sampling seed.
+      scheduling_seed: The schedule decision rule sampling seed.
+      name: Name scope for any created operations.
+
+    Raises:
+      ValueError: if `sampling_probability` is not a scalar or vector.
+    """
+    with ops.name_scope(name, "ScheduledEmbeddingSamplingWrapper",
+                        [embedding, sampling_probability]):
+      if callable(embedding):
+        self._embedding_fn = embedding
+      else:
+        self._embedding_fn = (
+            lambda ids: embedding_ops.embedding_lookup(embedding, ids))
+      self._sampling_probability = ops.convert_to_tensor(
+          sampling_probability, name="sampling_probability")
+      if self._sampling_probability.get_shape().ndims not in (0, 1):
+        raise ValueError(
+            "sampling_probability must be either a scalar or a vector. "
+            "saw shape: %s" % (self._sampling_probability.get_shape()))
+      self._seed = seed
+      self._scheduling_seed = scheduling_seed
+      super(ScheduledEmbeddingTrainingSampler, self).__init__(
+          inputs=inputs,
+          sequence_length=sequence_length,
+          time_major=time_major,
+          name=name)
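
Per the docstring above, embedding may be either a lookup matrix (the params argument to embedding_lookup) or a callable from ids to vectors. A hedged construction sketch; embedding_matrix and seq_len are hypothetical names:

# Matrix form: ids are embedded via embedding_ops.embedding_lookup.
sampler = sampling_decoder.ScheduledEmbeddingTrainingSampler(
    inputs=inputs, sequence_length=seq_len,
    embedding=embedding_matrix, sampling_probability=0.5)

# Callable form: any ids -> vectors function.
sampler = sampling_decoder.ScheduledEmbeddingTrainingSampler(
    inputs=inputs, sequence_length=seq_len,
    embedding=lambda ids: embedding_ops.embedding_lookup(embedding_matrix, ids),
    sampling_probability=0.5)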
+
+  def initialize(self, name=None):
+    return super(ScheduledEmbeddingTrainingSampler, self).initialize(
+        name=name)
+
+  def sample(self, time, outputs, state, name=None):
+    with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample",
+                        [time, outputs, state]):
+      # Return -1s where we did not sample, and sample_ids elsewhere
+      select_sample_noise = random_ops.random_uniform(
+          [self.batch_size], seed=self._scheduling_seed)
+      select_sample = (self._sampling_probability > select_sample_noise)
+      sample_id_sampler = categorical.Categorical(logits=outputs)
+      return array_ops.where(
+          select_sample,
+          sample_id_sampler.sample(seed=self._seed),
+          array_ops.tile([-1], [self.batch_size]))
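
sample() above draws one uniform variate per batch entry and samples categorically from the logits only where the variate falls below sampling_probability; everything else gets -1. The same decision rule restated in NumPy, with hypothetical values:

import numpy as np

rng = np.random.RandomState(0)
batch_size, vocab = 5, 10
sampling_probability = 0.5
logits = rng.randn(batch_size, vocab)

select_sample = sampling_probability > rng.uniform(size=batch_size)
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
drawn = np.array([rng.choice(vocab, p=p) for p in probs])  # categorical draw
sample_ids = np.where(select_sample, drawn, -1)  # -1 marks "no sampling"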
+
+  def next_inputs(self, time, outputs, state, sample_ids, name=None):
+    with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample",
+                        [time, outputs, state, sample_ids]):
+      (finished, base_next_inputs, state) = (
+          super(ScheduledEmbeddingTrainingSampler, self).next_inputs(
+              time=time, outputs=outputs, state=state, sample_ids=sample_ids,
+              name=name))
+
+      def maybe_sample():
+        where_sampling = math_ops.cast(sample_ids > -1, dtypes.int32)
+        _, sample_ids_sampling = data_flow_ops.dynamic_partition(
+            sample_ids, where_sampling, 2)
+        inputs_not_sampling, _ = data_flow_ops.dynamic_partition(
+            base_next_inputs, where_sampling, 2)
+        partitioned_indices = data_flow_ops.dynamic_partition(
+            math_ops.range(array_ops.size(where_sampling)), where_sampling, 2)
+        sampled_next_inputs = self._embedding_fn(sample_ids_sampling)
+        return data_flow_ops.dynamic_stitch(
+            partitioned_indices, (inputs_not_sampling, sampled_next_inputs))
+
+      all_finished = math_ops.reduce_all(finished)
+      next_inputs = control_flow_ops.cond(
+          all_finished, lambda: base_next_inputs, maybe_sample)
+      return (finished, next_inputs, state)
+
+
 class GreedyEmbeddingSampler(Sampler):
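
maybe_sample() above is a partition-and-stitch scatter: split the batch by whether sampling fired, embed only the sampled ids, and merge the two partitions back in original batch order. A NumPy equivalent of the dynamic_partition/dynamic_stitch round trip, with hypothetical values:

import numpy as np

sample_ids = np.array([4, -1, 7, -1])      # -1 => keep the teacher-forced input
base_next_inputs = np.random.randn(4, 3)   # what BasicTrainingSampler would feed
embedding = np.random.randn(10, 3)

where_sampling = (sample_ids > -1).astype(np.int32)  # partition key: 1 = sampled
idx = np.arange(sample_ids.size)
idx_not_sampling = idx[where_sampling == 0]  # dynamic_partition(range(n), key, 2)[0]
idx_sampling = idx[where_sampling == 1]      # dynamic_partition(range(n), key, 2)[1]

# dynamic_stitch scatters each partition back to its original row indices.
next_inputs = np.empty_like(base_next_inputs)
next_inputs[idx_not_sampling] = base_next_inputs[idx_not_sampling]
next_inputs[idx_sampling] = embedding[sample_ids[idx_sampling]]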