Add new scheduled sampler to the new seq2seq API.
Change: 146519142
parent 56f266acf2
commit 6abf7cfee9
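
The new ScheduledEmbeddingTrainingSampler plugs into BasicSamplingDecoder the
same way BasicTrainingSampler does. A minimal usage sketch, mirroring the test
added below (the shapes, the 0.5 sampling probability, and the single manual
decode step are illustrative only, not part of the API):

    import numpy as np
    from tensorflow.contrib.rnn import core_rnn_cell
    from tensorflow.contrib.seq2seq.python.ops import sampling_decoder
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes

    batch_size, max_time, depth, vocab = 5, 8, 7, 10
    inputs = np.random.randn(batch_size, max_time, depth).astype(np.float32)
    embeddings = np.random.randn(vocab, depth).astype(np.float32)

    sampler = sampling_decoder.ScheduledEmbeddingTrainingSampler(
        inputs=inputs, sequence_length=[3, 4, 3, 1, 0],
        embedding=embeddings,  # or any callable mapping ids -> embeddings
        sampling_probability=constant_op.constant(0.5))
    cell = core_rnn_cell.LSTMCell(vocab)
    my_decoder = sampling_decoder.BasicSamplingDecoder(
        cell=cell, sampler=sampler,
        initial_state=cell.zero_state(dtype=dtypes.float32,
                                      batch_size=batch_size))
    # One manual step; sample_id is -1 wherever the schedule kept the
    # ground-truth input instead of sampling from the output logits.
    finished, first_inputs, state = my_decoder.initialize()
    outputs, state, next_inputs, finished = my_decoder.step(
        constant_op.constant(0), first_inputs, state)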
@@ -14,6 +14,7 @@ py_library(
     srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/contrib/distributions:distributions_py",
         "//tensorflow/contrib/layers:layers_py",
         "//tensorflow/contrib/rnn:rnn_py",
         "//tensorflow/python:array_ops",
@@ -176,6 +176,87 @@ class BasicSamplingDecoderTest(test.TestCase):
       self.assertAllEqual(expected_step_next_inputs,
                           sess_results["step_next_inputs"])

+  def testStepWithScheduledEmbeddingTrainingSampler(self):
+    sequence_length = [3, 4, 3, 1, 0]
+    batch_size = 5
+    max_time = 8
+    input_depth = 7
+    vocabulary_size = 10
+
+    with self.test_session() as sess:
+      inputs = np.random.randn(
+          batch_size, max_time, input_depth).astype(np.float32)
+      embeddings = np.random.randn(
+          vocabulary_size, input_depth).astype(np.float32)
+      half = constant_op.constant(0.5)
+      cell = core_rnn_cell.LSTMCell(vocabulary_size)
+      sampler = sampling_decoder.ScheduledEmbeddingTrainingSampler(
+          inputs=inputs, sequence_length=sequence_length,
+          embedding=embeddings, sampling_probability=half,
+          time_major=False)
+      my_decoder = sampling_decoder.BasicSamplingDecoder(
+          cell=cell,
+          sampler=sampler,
+          initial_state=cell.zero_state(
+              dtype=dtypes.float32, batch_size=batch_size))
+      output_size = my_decoder.output_size
+      output_dtype = my_decoder.output_dtype
+      self.assertEqual(
+          sampling_decoder.SamplingDecoderOutput(vocabulary_size,
+                                                 tensor_shape.TensorShape([])),
+          output_size)
+      self.assertEqual(
+          sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32),
+          output_dtype)
+
+      (first_finished, first_inputs, first_state) = my_decoder.initialize()
+      (step_outputs, step_state, step_next_inputs,
+       step_finished) = my_decoder.step(
+           constant_op.constant(0), first_inputs, first_state)
+      batch_size_t = my_decoder.batch_size
+
+      self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
+      self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+      self.assertTrue(
+          isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput))
+      self.assertEqual((batch_size, vocabulary_size),
+                       step_outputs[0].get_shape())
+      self.assertEqual((batch_size,), step_outputs[1].get_shape())
+      self.assertEqual((batch_size, vocabulary_size),
+                       first_state[0].get_shape())
+      self.assertEqual((batch_size, vocabulary_size),
+                       first_state[1].get_shape())
+      self.assertEqual((batch_size, vocabulary_size),
+                       step_state[0].get_shape())
+      self.assertEqual((batch_size, vocabulary_size),
+                       step_state[1].get_shape())
+
+      sess.run(variables.global_variables_initializer())
+      sess_results = sess.run({
+          "batch_size": batch_size_t,
+          "first_finished": first_finished,
+          "first_inputs": first_inputs,
+          "first_state": first_state,
+          "step_outputs": step_outputs,
+          "step_state": step_state,
+          "step_next_inputs": step_next_inputs,
+          "step_finished": step_finished
+      })
+
+      self.assertAllEqual([False, False, False, False, True],
+                          sess_results["first_finished"])
+      self.assertAllEqual([False, False, False, True, True],
+                          sess_results["step_finished"])
+      sample_ids = sess_results["step_outputs"].sample_id
+      batch_where_not_sampling = np.where(sample_ids == -1)
+      batch_where_sampling = np.where(sample_ids > -1)
+      self.assertAllEqual(
+          sess_results["step_next_inputs"][batch_where_sampling],
+          embeddings[sample_ids[batch_where_sampling]])
+      self.assertAllEqual(
+          sess_results["step_next_inputs"][batch_where_not_sampling],
+          np.squeeze(inputs[batch_where_not_sampling, 1]))
+

 if __name__ == "__main__":
   test.main()
@@ -24,6 +24,7 @@ import collections

 import six

+from tensorflow.contrib.distributions.python.ops import categorical
 from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.contrib.seq2seq.python.ops import decoder
 from tensorflow.python.framework import dtypes
@@ -31,14 +32,17 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.util import nest

 __all__ = [
     "Sampler", "SamplingDecoderOutput", "BasicSamplingDecoder",
     "BasicTrainingSampler", "GreedyEmbeddingSampler", "CustomSampler",
+    "ScheduledEmbeddingTrainingSampler",
 ]

 _transpose_batch_time = decoder._transpose_batch_time  # pylint: disable=protected-access
@@ -205,69 +209,167 @@ class BasicTrainingSampler(Sampler):
   Returned sample_ids are the argmax of the RNN output logits.
   """

-  def __init__(self, inputs, sequence_length, time_major=False):
+  def __init__(self, inputs, sequence_length, time_major=False, name=None):
     """Initializer.

     Args:
       inputs: A (structure of) input tensors.
       sequence_length: An int32 vector tensor.
-      time_major: Python bool.
+      time_major: Python bool. Whether the tensors in `inputs` are time major.
+        If `False` (default), they are assumed to be batch major.
+      name: Name scope for any created operations.

     Raises:
       ValueError: if `sequence_length` is not a 1D tensor.
     """
-    inputs = ops.convert_to_tensor(inputs, name="inputs")
-    if not time_major:
-      inputs = nest.map_structure(_transpose_batch_time, inputs)
+    with ops.name_scope(
+        name, "BasicTrainingSampler", [inputs, sequence_length]):
+      inputs = ops.convert_to_tensor(inputs, name="inputs")
+      if not time_major:
+        inputs = nest.map_structure(_transpose_batch_time, inputs)

-    def _unstack_ta(inp):
-      return tensor_array_ops.TensorArray(
-          dtype=inp.dtype, size=array_ops.shape(inp)[0],
-          element_shape=inp.get_shape()[1:]).unstack(inp)
+      def _unstack_ta(inp):
+        return tensor_array_ops.TensorArray(
+            dtype=inp.dtype, size=array_ops.shape(inp)[0],
+            element_shape=inp.get_shape()[1:]).unstack(inp)

-    self._input_tas = nest.map_structure(_unstack_ta, inputs)
-    self._sequence_length = ops.convert_to_tensor(
-        sequence_length, name="sequence_length")
-    if self._sequence_length.get_shape().ndims != 1:
-      raise ValueError(
-          "Expected sequence_length to be a vector, but received shape: %s" %
-          self._sequence_length.get_shape())
+      self._input_tas = nest.map_structure(_unstack_ta, inputs)
+      self._sequence_length = ops.convert_to_tensor(
+          sequence_length, name="sequence_length")
+      if self._sequence_length.get_shape().ndims != 1:
+        raise ValueError(
+            "Expected sequence_length to be a vector, but received shape: %s" %
+            self._sequence_length.get_shape())

-    self._zero_inputs = nest.map_structure(
-        lambda inp: array_ops.zeros_like(inp[0, :]), inputs)
+      self._zero_inputs = nest.map_structure(
+          lambda inp: array_ops.zeros_like(inp[0, :]), inputs)

-    self._batch_size = array_ops.size(sequence_length)
+      self._batch_size = array_ops.size(sequence_length)

   @property
   def batch_size(self):
     return self._batch_size

   def initialize(self, name=None):
-    finished = math_ops.equal(0, self._sequence_length)
-    all_finished = math_ops.reduce_all(finished)
-    next_inputs = control_flow_ops.cond(
-        all_finished, lambda: self._zero_inputs,
-        lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
-    return (finished, next_inputs)
+    with ops.name_scope(name, "BasicTrainingSamplerInitialize"):
+      finished = math_ops.equal(0, self._sequence_length)
+      all_finished = math_ops.reduce_all(finished)
+      next_inputs = control_flow_ops.cond(
+          all_finished, lambda: self._zero_inputs,
+          lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
+      return (finished, next_inputs)

   def sample(self, time, outputs, name=None, **unused_kwargs):
-    del time  # unused by sample_fn
-    sample_ids = math_ops.cast(
-        math_ops.argmax(outputs, axis=-1), dtypes.int32)
-    return sample_ids
+    with ops.name_scope(name, "BasicTrainingSamplerSample", [time, outputs]):
+      sample_ids = math_ops.cast(
+          math_ops.argmax(outputs, axis=-1), dtypes.int32)
+      return sample_ids

   def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
     """next_inputs_fn for BasicTrainingSampler."""
-    del outputs  # unused by next_inputs_fn
-    next_time = time + 1
-    finished = (next_time >= self._sequence_length)
-    all_finished = math_ops.reduce_all(finished)
-    def read_from_ta(inp):
-      return inp.read(next_time)
-    next_inputs = control_flow_ops.cond(
-        all_finished, lambda: self._zero_inputs,
-        lambda: nest.map_structure(read_from_ta, self._input_tas))
-    return (finished, next_inputs, state)
+    with ops.name_scope(
+        name, "BasicTrainingSamplerNextInputs", [time, outputs, state]):
+      next_time = time + 1
+      finished = (next_time >= self._sequence_length)
+      all_finished = math_ops.reduce_all(finished)
+      def read_from_ta(inp):
+        return inp.read(next_time)
+      next_inputs = control_flow_ops.cond(
+          all_finished, lambda: self._zero_inputs,
+          lambda: nest.map_structure(read_from_ta, self._input_tas))
+      return (finished, next_inputs, state)

+class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler):
+  """A training sampler that adds scheduled sampling.
+
+  Returns -1s for sample_ids where no sampling took place; valid sample id
+  values elsewhere.
+  """
+
+  def __init__(self, inputs, sequence_length, embedding, sampling_probability,
+               time_major=False, seed=None, scheduling_seed=None, name=None):
+    """Initializer.
+
+    Args:
+      inputs: A (structure of) input tensors.
+      sequence_length: An int32 vector tensor.
+      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
+        or the `params` argument for `embedding_lookup`.
+      sampling_probability: A 0D `float32` tensor: the probability of sampling
+        categorically from the output ids instead of reading directly from the
+        inputs.
+      time_major: Python bool. Whether the tensors in `inputs` are time major.
+        If `False` (default), they are assumed to be batch major.
+      seed: The sampling seed.
+      scheduling_seed: The schedule decision rule sampling seed.
+      name: Name scope for any created operations.
+
+    Raises:
+      ValueError: if `sampling_probability` is not a scalar or vector.
+    """
+    with ops.name_scope(name, "ScheduledEmbeddingSamplingWrapper",
+                        [embedding, sampling_probability]):
+      if callable(embedding):
+        self._embedding_fn = embedding
+      else:
+        self._embedding_fn = (
+            lambda ids: embedding_ops.embedding_lookup(embedding, ids))
+      self._sampling_probability = ops.convert_to_tensor(
+          sampling_probability, name="sampling_probability")
+      if self._sampling_probability.get_shape().ndims not in (0, 1):
+        raise ValueError(
+            "sampling_probability must be either a scalar or a vector. "
+            "saw shape: %s" % (self._sampling_probability.get_shape()))
+      self._seed = seed
+      self._scheduling_seed = scheduling_seed
+      super(ScheduledEmbeddingTrainingSampler, self).__init__(
+          inputs=inputs,
+          sequence_length=sequence_length,
+          time_major=time_major,
+          name=name)
+
+  def initialize(self, name=None):
+    return super(ScheduledEmbeddingTrainingSampler, self).initialize(
+        name=name)
+
+  def sample(self, time, outputs, state, name=None):
+    with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample",
+                        [time, outputs, state]):
+      # Return -1s where we did not sample, and sample_ids elsewhere.
+      select_sample_noise = random_ops.random_uniform(
+          [self.batch_size], seed=self._scheduling_seed)
+      select_sample = (self._sampling_probability > select_sample_noise)
+      sample_id_sampler = categorical.Categorical(logits=outputs)
+      return array_ops.where(
+          select_sample,
+          sample_id_sampler.sample(seed=self._seed),
+          array_ops.tile([-1], [self.batch_size]))
+
+  def next_inputs(self, time, outputs, state, sample_ids, name=None):
+    with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerNextInputs",
+                        [time, outputs, state, sample_ids]):
+      (finished, base_next_inputs, state) = (
+          super(ScheduledEmbeddingTrainingSampler, self).next_inputs(
+              time=time, outputs=outputs, state=state, sample_ids=sample_ids,
+              name=name))
+
+      def maybe_sample():
+        where_sampling = math_ops.cast(sample_ids > -1, dtypes.int32)
+        _, sample_ids_sampling = data_flow_ops.dynamic_partition(
+            sample_ids, where_sampling, 2)
+        inputs_not_sampling, _ = data_flow_ops.dynamic_partition(
+            base_next_inputs, where_sampling, 2)
+        partitioned_indices = data_flow_ops.dynamic_partition(
+            math_ops.range(array_ops.size(where_sampling)), where_sampling, 2)
+        sampled_next_inputs = self._embedding_fn(sample_ids_sampling)
+        return data_flow_ops.dynamic_stitch(
+            partitioned_indices, (inputs_not_sampling, sampled_next_inputs))
+
+      all_finished = math_ops.reduce_all(finished)
+      next_inputs = control_flow_ops.cond(
+          all_finished, lambda: base_next_inputs, maybe_sample)
+      return (finished, next_inputs, state)
+
+
 class GreedyEmbeddingSampler(Sampler):
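
A note on the maybe_sample branch above: it relies on a partition/stitch idiom,
splitting the batch by whether each row sampled, embedding only the sampled
ids, then stitching both partitions back into batch order. A minimal NumPy
sketch of the same semantics, with made-up sizes and values (illustrative
only, not part of the change):

    import numpy as np

    sample_ids = np.array([4, -1, 7, -1])     # -1 marks "did not sample"
    base_next_inputs = np.random.randn(4, 3)  # ground-truth inputs at t+1
    embeddings = np.random.randn(10, 3)

    where_sampling = (sample_ids > -1).astype(np.int32)
    idx = np.arange(sample_ids.size)
    # dynamic_partition: partition 0 keeps ground truth, partition 1 sampled.
    idx_not_sampling = idx[where_sampling == 0]
    idx_sampling = idx[where_sampling == 1]
    sampled_next_inputs = embeddings[sample_ids[idx_sampling]]

    # dynamic_stitch: scatter both partitions back into batch order.
    next_inputs = np.empty_like(base_next_inputs)
    next_inputs[idx_not_sampling] = base_next_inputs[idx_not_sampling]
    next_inputs[idx_sampling] = sampled_next_inputs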