Update keras.backend.placeholder to preserve shape (e.g. batch size) information when constructing a ragged placeholder.
PiperOrigin-RevId: 266043329
This commit is contained in:
parent
a1a5f93073
commit
127fdf9c7c
tensorflow/python/keras
@ -69,7 +69,7 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
@ -1038,16 +1038,13 @@ def placeholder(shape=None,
|
||||
ragged_rank = 0
|
||||
for i in range(1, len(shape)):
|
||||
if shape[i] is None:
|
||||
ragged_rank += 1
|
||||
else:
|
||||
break
|
||||
value_shape = shape[(ragged_rank + 1):]
|
||||
|
||||
x = ragged_factory_ops.placeholder(
|
||||
dtype=dtype,
|
||||
ragged_rank=ragged_rank,
|
||||
value_shape=value_shape,
|
||||
name=name)
|
||||
ragged_rank = i
|
||||
type_spec = ragged_tensor.RaggedTensorSpec(
|
||||
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
|
||||
def tensor_spec_to_placeholder(tensorspec):
|
||||
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
|
||||
x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
|
||||
expand_composites=True)
|
||||
else:
|
||||
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
||||
return x
|
||||
|
@ -510,7 +510,10 @@ class RaggedTensorInputTest(keras_parameterized.TestCase,
|
||||
# Prepare the model to test.
|
||||
input_name = get_input_name(use_dict)
|
||||
model_input = input_layer.Input(
|
||||
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32)
|
||||
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32,
|
||||
batch_size=2)
|
||||
self.assertIsInstance(model_input, ragged_tensor.RaggedTensor)
|
||||
self.assertEqual(model_input.shape.as_list(), [2, None, None])
|
||||
layers = [ToDense(default_value=-1)]
|
||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||
model.compile(
|
||||
|
Loading…
Reference in New Issue
Block a user