diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 36f9a444cc4..29b36ef3ba4 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -51,10 +51,6 @@ class Wrapper(Layer): def __init__(self, layer, **kwargs): assert isinstance(layer, Layer) self.layer = layer - # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when - # the inner layer has update ops that depend on its inputs (as opposed - # to the inputs to the Wrapper layer). - self._input_map = {} super(Wrapper, self).__init__(**kwargs) def build(self, input_shape=None): @@ -258,9 +254,7 @@ class TimeDistributed(Wrapper): inner_input_shape = self._get_shape_tuple((-1,), inputs, 2) # Shape: (num_samples * timesteps, ...). And track the # transformation in self._input_map. - input_uid = generic_utils.object_list_uid(inputs) inputs = array_ops.reshape(inputs, inner_input_shape) - self._input_map[input_uid] = inputs # (num_samples * timesteps, ...) if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None: inner_mask_shape = self._get_shape_tuple((-1,), mask, 2) @@ -314,15 +308,17 @@ class TimeDistributed(Wrapper): # cases need to call the layer.compute_mask when input_mask is None: # Masking layer and Embedding layer with mask_zero input_shape = K.int_shape(inputs) - if input_shape[0]: - # batch size matters, we currently do not handle mask explicitly + if input_shape[0] and not self._always_use_reshape or isinstance( + inputs, ragged_tensor.RaggedTensor): + # batch size matters, we currently do not handle mask explicitly, or if + # the layer always uses reshape approach, or the input is a ragged tensor. return mask inner_mask = mask if inner_mask is not None: inner_mask_shape = self._get_shape_tuple((-1,), mask, 2) inner_mask = K.reshape(inner_mask, inner_mask_shape) - input_uid = generic_utils.object_list_uid(inputs) - inner_inputs = self._input_map.get(input_uid, inputs) + inner_input_shape = self._get_shape_tuple((-1,), inputs, 2) + inner_inputs = array_ops.reshape(inputs, inner_input_shape) output_mask = self.layer.compute_mask(inner_inputs, inner_mask) if output_mask is None: if mask is None: diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index d5f21135b9d..fc855ee6428 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -186,7 +186,6 @@ class TimeDistributedTest(keras_parameterized.TestCase): y = model.predict(np.random.random((10, 3, 2))) self.assertAllClose(np.mean(y), 0., atol=1e-1, rtol=1e-1) - @tf_test_util.run_v1_only(reason='b/148248386') def test_TimeDistributed_batchnorm(self): with self.cached_session(): # test that wrapped BN updates still work. @@ -206,8 +205,6 @@ class TimeDistributedTest(keras_parameterized.TestCase): # Assert that mean and variance changed. assert not np.array_equal(td.get_weights()[2], np.array([0, 0])) assert not np.array_equal(td.get_weights()[3], np.array([1, 1])) - # Verify input_map has one mapping from inputs to reshaped inputs. - self.assertEqual(len(td._input_map.keys()), 1) def test_TimeDistributed_trainable(self): # test layers that need learning_phase to be set