Fork is_composite_or_commposite_value into Keras to split a dependency on private symbols

PiperOrigin-RevId: 339120462
Change-Id: I8c48668daf0daf330b4d34bc58b4a20f1a9de67c
This commit is contained in:
Tomer Kaftan 2020-10-26 14:38:11 -07:00 committed by TensorFlower Gardener
parent 47b9d66592
commit b8fcda3cd1
2 changed files with 20 additions and 11 deletions

View File

@ -36,11 +36,13 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import composite_tensor_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
@ -55,11 +57,22 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
def is_composite_or_composite_value(tensor):
"""Returns true if 'tensor' is a CompositeTensor or a CT Value object."""
# TODO(b/125094323): This should be isinstance(CompositeTensor) or
# isinstance(CompositeTensorValue) once we support that.
return isinstance(
tensor,
(composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue,
ragged_tensor_value.RaggedTensorValue))
@six.add_metaclass(abc.ABCMeta)
class Aggregator(object):
"""Abstract base class used to aggregate batch-level outputs of a loop.
@ -156,8 +169,7 @@ class ConcatAggregator(Aggregator):
use_steps=True, num_samples=None, steps=None, batch_size=batch_size)
def create(self, batch_element):
self.composite = composite_tensor_utils.is_composite_or_composite_value(
batch_element)
self.composite = is_composite_or_composite_value(batch_element)
def aggregate(self, batch_element, batch_start=None, batch_end=None):
@ -313,12 +325,11 @@ class OutputsAggregator(Aggregator):
# SparseTensorValue is a named tuple which nest will flatten, so we need
# to guard it to properly handle the structure.
self._structure = nest.get_traverse_shallow_structure(
lambda x: not composite_tensor_utils.is_composite_or_composite_value(x),
batch_outs)
lambda x: not is_composite_or_composite_value(x), batch_outs)
batch_outs = nest.flatten_up_to(self._structure, batch_outs)
for batch_element in batch_outs:
if composite_tensor_utils.is_composite_or_composite_value(batch_element):
if is_composite_or_composite_value(batch_element):
# If the output is not a ndarray, it will be either a composite tensor
# or a composite tensor's Value object. In either case, we can't
# allocate an array to hold the object - we'll handle it later.
@ -399,7 +410,7 @@ def standardize_single_array(x, expected_shape=None):
if x is None:
return None
if composite_tensor_utils.is_composite_or_composite_value(x):
if is_composite_or_composite_value(x):
return x
if isinstance(x, int):
@ -517,7 +528,7 @@ def standardize_input_data(data,
if not tensorshape:
continue
data_shape = tuple(tensorshape.as_list())
elif composite_tensor_utils.is_composite_or_composite_value(data[i]):
elif is_composite_or_composite_value(data[i]):
tensorshape = composite_tensor_utils.get_shape(data[i])
data_shape = tuple(tensorshape.as_list())
else:
@ -610,8 +621,7 @@ def check_array_lengths(inputs, targets, weights=None):
"""
def is_tensor_or_composite_tensor(x):
return tensor_util.is_tensor(
x) or composite_tensor_utils.is_composite_or_composite_value(x)
return tensor_util.is_tensor(x) or is_composite_or_composite_value(x)
def set_of_lengths(x):
# Returns a set with the variation between

View File

@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@ -2495,7 +2494,7 @@ class Model(training_lib.Model):
# users should explicitly add composite tensor inputs to their subclassed
# models.
for input_tensor in processed_inputs:
if composite_tensor_utils.is_composite_or_composite_value(input_tensor):
if training_utils_v1.is_composite_or_composite_value(input_tensor):
# TODO(b/132691975): Document subclass-model CT input handling.
raise ValueError(
'All SparseTensor and RaggedTensor inputs must be explicitly '