diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 00fef68728d..2ca2e7382e1 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -69,7 +69,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables as variables_module -from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest from tensorflow.python.util import object_identity @@ -1038,16 +1038,13 @@ def placeholder(shape=None, ragged_rank = 0 for i in range(1, len(shape)): if shape[i] is None: - ragged_rank += 1 - else: - break - value_shape = shape[(ragged_rank + 1):] - - x = ragged_factory_ops.placeholder( - dtype=dtype, - ragged_rank=ragged_rank, - value_shape=value_shape, - name=name) + ragged_rank = i + type_spec = ragged_tensor.RaggedTensorSpec( + shape=shape, dtype=dtype, ragged_rank=ragged_rank) + def tensor_spec_to_placeholder(tensorspec): + return array_ops.placeholder(tensorspec.dtype, tensorspec.shape) + x = nest.map_structure(tensor_spec_to_placeholder, type_spec, + expand_composites=True) else: x = array_ops.placeholder(dtype, shape=shape, name=name) return x diff --git a/tensorflow/python/keras/utils/composite_tensor_support_test.py b/tensorflow/python/keras/utils/composite_tensor_support_test.py index b5a1d514766..bfb56674f9b 100644 --- a/tensorflow/python/keras/utils/composite_tensor_support_test.py +++ b/tensorflow/python/keras/utils/composite_tensor_support_test.py @@ -510,7 +510,10 @@ class RaggedTensorInputTest(keras_parameterized.TestCase, # Prepare the model to test. input_name = get_input_name(use_dict) model_input = input_layer.Input( - shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32) + shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32, + batch_size=2) + self.assertIsInstance(model_input, ragged_tensor.RaggedTensor) + self.assertEqual(model_input.shape.as_list(), [2, None, None]) layers = [ToDense(default_value=-1)] model = get_model_from_layers_with_input(layers, model_input=model_input) model.compile(