Fix sparse kerastensors to maintain dense shape information after converting to a placeholder.
PiperOrigin-RevId: 316538468 Change-Id: I8e53a7e96067a8b7edd3f57cd8a8a89eb912824b
This commit is contained in:
parent
a8950d70bf
commit
7d76bc4b60
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import type_spec as type_spec_module
|
from tensorflow.python.framework import type_spec as type_spec_module
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
@ -210,10 +211,20 @@ class _KerasTensorIterator(object):
|
||||||
def keras_tensor_to_placeholder(x):
|
def keras_tensor_to_placeholder(x):
|
||||||
"""TODO(kaftan): Docstring."""
|
"""TODO(kaftan): Docstring."""
|
||||||
if isinstance(x, KerasTensor):
|
if isinstance(x, KerasTensor):
|
||||||
def tensor_spec_to_placeholder(tensorspec):
|
spec = x.type_spec
|
||||||
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
|
if isinstance(spec, sparse_tensor.SparseTensorSpec):
|
||||||
ph = nest.map_structure(tensor_spec_to_placeholder, x.type_spec,
|
# nest.map_structure loses dense shape information for sparse tensors.
|
||||||
expand_composites=True)
|
# So, we special-case sparse placeholder creation.
|
||||||
|
# This only preserves shape information for top-level sparse tensors;
|
||||||
|
# not for sparse tensors that are nested inside another composite
|
||||||
|
# tensor.
|
||||||
|
return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)
|
||||||
|
|
||||||
|
def component_to_placeholder(component):
|
||||||
|
return array_ops.placeholder(component.dtype, component.shape)
|
||||||
|
|
||||||
|
ph = nest.map_structure(
|
||||||
|
component_to_placeholder, spec, expand_composites=True)
|
||||||
return ph
|
return ph
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
|
@ -603,8 +603,7 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types()
|
@keras_parameterized.run_with_all_model_types()
|
||||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
skip_keras_tensors=True)
|
|
||||||
class CompositeTensorModelPredictTest(keras_parameterized.TestCase):
|
class CompositeTensorModelPredictTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
def _normalize_shape(self, shape):
|
def _normalize_shape(self, shape):
|
||||||
|
|
|
@ -3184,6 +3184,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
|
||||||
# `SparseTensor`
|
# `SparseTensor`
|
||||||
dense_shape_default = tensor_shape.TensorShape(
|
dense_shape_default = tensor_shape.TensorShape(
|
||||||
tuple(None if dim == -1 else dim for dim in shape))
|
tuple(None if dim == -1 else dim for dim in shape))
|
||||||
|
shape = tuple(tensor_shape.dimension_value(dim) for dim in shape)
|
||||||
shape = tuple(-1 if dim is None else dim for dim in shape)
|
shape = tuple(-1 if dim is None else dim for dim in shape)
|
||||||
shape = ops.convert_to_tensor(
|
shape = ops.convert_to_tensor(
|
||||||
shape, dtype=dtypes.int64, name=default_shape_name)
|
shape, dtype=dtypes.int64, name=default_shape_name)
|
||||||
|
|
Loading…
Reference in New Issue