Merge pull request #33441 from Tetragramm:Issue_#33178
PiperOrigin-RevId: 290781306 Change-Id: I3fb111c482a9fc27263ccf9b39206747c0d76569
This commit is contained in:
commit
9f2aa61811
@ -54,7 +54,6 @@ class Wrapper(Layer):
|
||||
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
|
||||
# the inner layer has update ops that depend on its inputs (as opposed
|
||||
# to the inputs to the Wrapper layer).
|
||||
self._input_map = {}
|
||||
super(Wrapper, self).__init__(**kwargs)
|
||||
|
||||
def build(self, input_shape=None):
|
||||
@ -256,22 +255,19 @@ class TimeDistributed(Wrapper):
|
||||
if not input_length:
|
||||
input_length = array_ops.shape(inputs)[1]
|
||||
inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
|
||||
# Shape: (num_samples * timesteps, ...). And track the
|
||||
# transformation in self._input_map.
|
||||
input_uid = generic_utils.object_list_uid(inputs)
|
||||
# Shape: (num_samples * timesteps, ...).
|
||||
inputs = array_ops.reshape(inputs, inner_input_shape)
|
||||
self._input_map[input_uid] = inputs
|
||||
# (num_samples * timesteps, ...)
|
||||
if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
|
||||
inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
|
||||
kwargs['mask'] = K.reshape(mask, inner_mask_shape)
|
||||
|
||||
y = self.layer(inputs, **kwargs)
|
||||
|
||||
# Shape: (num_samples, timesteps, ...)
|
||||
output_shape = self.compute_output_shape(input_shape).as_list()
|
||||
output_shape = self._get_shape_tuple((-1, input_length), y, 1,
|
||||
output_shape[2:])
|
||||
|
||||
y = array_ops.reshape(y, output_shape)
|
||||
|
||||
return y
|
||||
@ -321,9 +317,11 @@ class TimeDistributed(Wrapper):
|
||||
if inner_mask is not None:
|
||||
inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
|
||||
inner_mask = K.reshape(inner_mask, inner_mask_shape)
|
||||
input_uid = generic_utils.object_list_uid(inputs)
|
||||
inner_inputs = self._input_map.get(input_uid, inputs)
|
||||
output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
|
||||
#Reshape inputs because that's what call() does
|
||||
#and we aren't saving the shape in an dict anymore
|
||||
inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
|
||||
inputs = array_ops.reshape(inputs, inner_input_shape)
|
||||
output_mask = self.layer.compute_mask(inputs, inner_mask)
|
||||
if output_mask is None:
|
||||
if mask is None:
|
||||
return None
|
||||
|
Loading…
Reference in New Issue
Block a user