Merge pull request #33441 from Tetragramm:Issue_#33178

PiperOrigin-RevId: 290781306
Change-Id: I3fb111c482a9fc27263ccf9b39206747c0d76569
This commit is contained in:
TensorFlower Gardener 2020-01-21 11:11:13 -08:00
commit 9f2aa61811

View File

@ -54,7 +54,6 @@ class Wrapper(Layer):
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
# the inner layer has update ops that depend on its inputs (as opposed
# to the inputs to the Wrapper layer).
self._input_map = {}
super(Wrapper, self).__init__(**kwargs)
def build(self, input_shape=None):
@ -256,22 +255,19 @@ class TimeDistributed(Wrapper):
if not input_length:
input_length = array_ops.shape(inputs)[1]
inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
# Shape: (num_samples * timesteps, ...). And track the
# transformation in self._input_map.
input_uid = generic_utils.object_list_uid(inputs)
# Shape: (num_samples * timesteps, ...).
inputs = array_ops.reshape(inputs, inner_input_shape)
self._input_map[input_uid] = inputs
# (num_samples * timesteps, ...)
if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
kwargs['mask'] = K.reshape(mask, inner_mask_shape)
y = self.layer(inputs, **kwargs)
# Shape: (num_samples, timesteps, ...)
output_shape = self.compute_output_shape(input_shape).as_list()
output_shape = self._get_shape_tuple((-1, input_length), y, 1,
output_shape[2:])
y = array_ops.reshape(y, output_shape)
return y
@ -321,9 +317,11 @@ class TimeDistributed(Wrapper):
if inner_mask is not None:
inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
inner_mask = K.reshape(inner_mask, inner_mask_shape)
input_uid = generic_utils.object_list_uid(inputs)
inner_inputs = self._input_map.get(input_uid, inputs)
output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
#Reshape inputs because that's what call() does
#and we aren't saving the shape in an dict anymore
inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
inputs = array_ops.reshape(inputs, inner_input_shape)
output_mask = self.layer.compute_mask(inputs, inner_mask)
if output_mask is None:
if mask is None:
return None