Update keras.backend.placeholder to preserve shape (e.g. batch size) information when constructing a ragged placeholder.
PiperOrigin-RevId: 266043329
This commit is contained in:
parent
a1a5f93073
commit
127fdf9c7c
@ -69,7 +69,7 @@ from tensorflow.python.ops import state_ops
|
|||||||
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.ops import variables as variables_module
|
from tensorflow.python.ops import variables as variables_module
|
||||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import object_identity
|
from tensorflow.python.util import object_identity
|
||||||
@ -1038,16 +1038,13 @@ def placeholder(shape=None,
|
|||||||
ragged_rank = 0
|
ragged_rank = 0
|
||||||
for i in range(1, len(shape)):
|
for i in range(1, len(shape)):
|
||||||
if shape[i] is None:
|
if shape[i] is None:
|
||||||
ragged_rank += 1
|
ragged_rank = i
|
||||||
else:
|
type_spec = ragged_tensor.RaggedTensorSpec(
|
||||||
break
|
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
|
||||||
value_shape = shape[(ragged_rank + 1):]
|
def tensor_spec_to_placeholder(tensorspec):
|
||||||
|
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
|
||||||
x = ragged_factory_ops.placeholder(
|
x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
|
||||||
dtype=dtype,
|
expand_composites=True)
|
||||||
ragged_rank=ragged_rank,
|
|
||||||
value_shape=value_shape,
|
|
||||||
name=name)
|
|
||||||
else:
|
else:
|
||||||
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
||||||
return x
|
return x
|
||||||
|
@ -510,7 +510,10 @@ class RaggedTensorInputTest(keras_parameterized.TestCase,
|
|||||||
# Prepare the model to test.
|
# Prepare the model to test.
|
||||||
input_name = get_input_name(use_dict)
|
input_name = get_input_name(use_dict)
|
||||||
model_input = input_layer.Input(
|
model_input = input_layer.Input(
|
||||||
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32)
|
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32,
|
||||||
|
batch_size=2)
|
||||||
|
self.assertIsInstance(model_input, ragged_tensor.RaggedTensor)
|
||||||
|
self.assertEqual(model_input.shape.as_list(), [2, None, None])
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(
|
model.compile(
|
||||||
|
Loading…
Reference in New Issue
Block a user