Update keras.backend.placeholder to preserve shape (e.g. batch size) information when constructing a ragged placeholder.

PiperOrigin-RevId: 266043329
This commit is contained in:
Edward Loper 2019-08-28 18:56:08 -07:00 committed by TensorFlower Gardener
parent a1a5f93073
commit 127fdf9c7c
2 changed files with 12 additions and 12 deletions

View File

@ -69,7 +69,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_module from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import object_identity from tensorflow.python.util import object_identity
@ -1038,16 +1038,13 @@ def placeholder(shape=None,
ragged_rank = 0 ragged_rank = 0
for i in range(1, len(shape)): for i in range(1, len(shape)):
if shape[i] is None: if shape[i] is None:
ragged_rank += 1 ragged_rank = i
else: type_spec = ragged_tensor.RaggedTensorSpec(
break shape=shape, dtype=dtype, ragged_rank=ragged_rank)
value_shape = shape[(ragged_rank + 1):] def tensor_spec_to_placeholder(tensorspec):
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
x = ragged_factory_ops.placeholder( x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
dtype=dtype, expand_composites=True)
ragged_rank=ragged_rank,
value_shape=value_shape,
name=name)
else: else:
x = array_ops.placeholder(dtype, shape=shape, name=name) x = array_ops.placeholder(dtype, shape=shape, name=name)
return x return x

View File

@ -510,7 +510,10 @@ class RaggedTensorInputTest(keras_parameterized.TestCase,
# Prepare the model to test. # Prepare the model to test.
input_name = get_input_name(use_dict) input_name = get_input_name(use_dict)
model_input = input_layer.Input( model_input = input_layer.Input(
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32) shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32,
batch_size=2)
self.assertIsInstance(model_input, ragged_tensor.RaggedTensor)
self.assertEqual(model_input.shape.as_list(), [2, None, None])
layers = [ToDense(default_value=-1)] layers = [ToDense(default_value=-1)]
model = get_model_from_layers_with_input(layers, model_input=model_input) model = get_model_from_layers_with_input(layers, model_input=model_input)
model.compile( model.compile(