Update keras.backend.placeholder to preserve shape (e.g. batch size) information when constructing a ragged placeholder.

PiperOrigin-RevId: 266043329
This commit is contained in:
Edward Loper 2019-08-28 18:56:08 -07:00 committed by TensorFlower Gardener
parent a1a5f93073
commit 127fdf9c7c
2 changed files with 12 additions and 12 deletions
tensorflow/python/keras

View File

@ -69,7 +69,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
@ -1038,16 +1038,13 @@ def placeholder(shape=None,
ragged_rank = 0
for i in range(1, len(shape)):
if shape[i] is None:
ragged_rank += 1
else:
break
value_shape = shape[(ragged_rank + 1):]
x = ragged_factory_ops.placeholder(
dtype=dtype,
ragged_rank=ragged_rank,
value_shape=value_shape,
name=name)
ragged_rank = i
type_spec = ragged_tensor.RaggedTensorSpec(
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
def tensor_spec_to_placeholder(tensorspec):
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
expand_composites=True)
else:
x = array_ops.placeholder(dtype, shape=shape, name=name)
return x

View File

@ -510,7 +510,10 @@ class RaggedTensorInputTest(keras_parameterized.TestCase,
# Prepare the model to test.
input_name = get_input_name(use_dict)
model_input = input_layer.Input(
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32)
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32,
batch_size=2)
self.assertIsInstance(model_input, ragged_tensor.RaggedTensor)
self.assertEqual(model_input.shape.as_list(), [2, None, None])
layers = [ToDense(default_value=-1)]
model = get_model_from_layers_with_input(layers, model_input=model_input)
model.compile(