Fix sparse kerastensors to maintain dense shape information after converting to a placeholder.
PiperOrigin-RevId: 316538468 Change-Id: I8e53a7e96067a8b7edd3f57cd8a8a89eb912824b
This commit is contained in:
parent
a8950d70bf
commit
7d76bc4b60
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import type_spec as type_spec_module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
@ -210,10 +211,20 @@ class _KerasTensorIterator(object):
|
|||
def keras_tensor_to_placeholder(x):
|
||||
"""TODO(kaftan): Docstring."""
|
||||
if isinstance(x, KerasTensor):
|
||||
def tensor_spec_to_placeholder(tensorspec):
|
||||
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
|
||||
ph = nest.map_structure(tensor_spec_to_placeholder, x.type_spec,
|
||||
expand_composites=True)
|
||||
spec = x.type_spec
|
||||
if isinstance(spec, sparse_tensor.SparseTensorSpec):
|
||||
# nest.map_structure loses dense shape information for sparse tensors.
|
||||
# So, we special-case sparse placeholder creation.
|
||||
# This only preserves shape information for top-level sparse tensors;
|
||||
# not for sparse tensors that are nested inside another composite
|
||||
# tensor.
|
||||
return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)
|
||||
|
||||
def component_to_placeholder(component):
|
||||
return array_ops.placeholder(component.dtype, component.shape)
|
||||
|
||||
ph = nest.map_structure(
|
||||
component_to_placeholder, spec, expand_composites=True)
|
||||
return ph
|
||||
else:
|
||||
return x
|
||||
|
|
|
@ -603,8 +603,7 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
|||
|
||||
|
||||
@keras_parameterized.run_with_all_model_types()
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
|
||||
skip_keras_tensors=True)
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
class CompositeTensorModelPredictTest(keras_parameterized.TestCase):
|
||||
|
||||
def _normalize_shape(self, shape):
|
||||
|
|
|
@ -3184,6 +3184,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
|
|||
# `SparseTensor`
|
||||
dense_shape_default = tensor_shape.TensorShape(
|
||||
tuple(None if dim == -1 else dim for dim in shape))
|
||||
shape = tuple(tensor_shape.dimension_value(dim) for dim in shape)
|
||||
shape = tuple(-1 if dim is None else dim for dim in shape)
|
||||
shape = ops.convert_to_tensor(
|
||||
shape, dtype=dtypes.int64, name=default_shape_name)
|
||||
|
|
Loading…
Reference in New Issue