Fork is_composite_or_commposite_value into Keras to split a dependency on private symbols
PiperOrigin-RevId: 339120462 Change-Id: I8c48668daf0daf330b4d34bc58b4a20f1a9de67c
This commit is contained in:
parent
47b9d66592
commit
b8fcda3cd1
@ -36,11 +36,13 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import composite_tensor_utils
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import smart_cond
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend as K
|
||||
@ -55,11 +57,22 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
|
||||
|
||||
def is_composite_or_composite_value(tensor):
|
||||
"""Returns true if 'tensor' is a CompositeTensor or a CT Value object."""
|
||||
# TODO(b/125094323): This should be isinstance(CompositeTensor) or
|
||||
# isinstance(CompositeTensorValue) once we support that.
|
||||
return isinstance(
|
||||
tensor,
|
||||
(composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue,
|
||||
ragged_tensor_value.RaggedTensorValue))
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Aggregator(object):
|
||||
"""Abstract base class used to aggregate batch-level outputs of a loop.
|
||||
@ -156,8 +169,7 @@ class ConcatAggregator(Aggregator):
|
||||
use_steps=True, num_samples=None, steps=None, batch_size=batch_size)
|
||||
|
||||
def create(self, batch_element):
|
||||
self.composite = composite_tensor_utils.is_composite_or_composite_value(
|
||||
batch_element)
|
||||
self.composite = is_composite_or_composite_value(batch_element)
|
||||
|
||||
def aggregate(self, batch_element, batch_start=None, batch_end=None):
|
||||
|
||||
@ -313,12 +325,11 @@ class OutputsAggregator(Aggregator):
|
||||
# SparseTensorValue is a named tuple which nest will flatten, so we need
|
||||
# to guard it to properly handle the structure.
|
||||
self._structure = nest.get_traverse_shallow_structure(
|
||||
lambda x: not composite_tensor_utils.is_composite_or_composite_value(x),
|
||||
batch_outs)
|
||||
lambda x: not is_composite_or_composite_value(x), batch_outs)
|
||||
batch_outs = nest.flatten_up_to(self._structure, batch_outs)
|
||||
|
||||
for batch_element in batch_outs:
|
||||
if composite_tensor_utils.is_composite_or_composite_value(batch_element):
|
||||
if is_composite_or_composite_value(batch_element):
|
||||
# If the output is not a ndarray, it will be either a composite tensor
|
||||
# or a composite tensor's Value object. In either case, we can't
|
||||
# allocate an array to hold the object - we'll handle it later.
|
||||
@ -399,7 +410,7 @@ def standardize_single_array(x, expected_shape=None):
|
||||
if x is None:
|
||||
return None
|
||||
|
||||
if composite_tensor_utils.is_composite_or_composite_value(x):
|
||||
if is_composite_or_composite_value(x):
|
||||
return x
|
||||
|
||||
if isinstance(x, int):
|
||||
@ -517,7 +528,7 @@ def standardize_input_data(data,
|
||||
if not tensorshape:
|
||||
continue
|
||||
data_shape = tuple(tensorshape.as_list())
|
||||
elif composite_tensor_utils.is_composite_or_composite_value(data[i]):
|
||||
elif is_composite_or_composite_value(data[i]):
|
||||
tensorshape = composite_tensor_utils.get_shape(data[i])
|
||||
data_shape = tuple(tensorshape.as_list())
|
||||
else:
|
||||
@ -610,8 +621,7 @@ def check_array_lengths(inputs, targets, weights=None):
|
||||
"""
|
||||
|
||||
def is_tensor_or_composite_tensor(x):
|
||||
return tensor_util.is_tensor(
|
||||
x) or composite_tensor_utils.is_composite_or_composite_value(x)
|
||||
return tensor_util.is_tensor(x) or is_composite_or_composite_value(x)
|
||||
|
||||
def set_of_lengths(x):
|
||||
# Returns a set with the variation between
|
||||
|
@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import parameter_server_strategy
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import composite_tensor_utils
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -2495,7 +2494,7 @@ class Model(training_lib.Model):
|
||||
# users should explicitly add composite tensor inputs to their subclassed
|
||||
# models.
|
||||
for input_tensor in processed_inputs:
|
||||
if composite_tensor_utils.is_composite_or_composite_value(input_tensor):
|
||||
if training_utils_v1.is_composite_or_composite_value(input_tensor):
|
||||
# TODO(b/132691975): Document subclass-model CT input handling.
|
||||
raise ValueError(
|
||||
'All SparseTensor and RaggedTensor inputs must be explicitly '
|
||||
|
Loading…
Reference in New Issue
Block a user