Fork is_composite_or_commposite_value into Keras to split a dependency on private symbols
PiperOrigin-RevId: 339120462 Change-Id: I8c48668daf0daf330b4d34bc58b4a20f1a9de67c
This commit is contained in:
parent
47b9d66592
commit
b8fcda3cd1
@ -36,11 +36,13 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
|
|||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import composite_tensor_utils
|
from tensorflow.python.framework import composite_tensor_utils
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import smart_cond
|
from tensorflow.python.framework import smart_cond
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
@ -55,11 +57,22 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import gen_array_ops
|
from tensorflow.python.ops import gen_array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.compat import collections_abc
|
from tensorflow.python.util.compat import collections_abc
|
||||||
|
|
||||||
|
|
||||||
|
def is_composite_or_composite_value(tensor):
|
||||||
|
"""Returns true if 'tensor' is a CompositeTensor or a CT Value object."""
|
||||||
|
# TODO(b/125094323): This should be isinstance(CompositeTensor) or
|
||||||
|
# isinstance(CompositeTensorValue) once we support that.
|
||||||
|
return isinstance(
|
||||||
|
tensor,
|
||||||
|
(composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue,
|
||||||
|
ragged_tensor_value.RaggedTensorValue))
|
||||||
|
|
||||||
|
|
||||||
@six.add_metaclass(abc.ABCMeta)
|
@six.add_metaclass(abc.ABCMeta)
|
||||||
class Aggregator(object):
|
class Aggregator(object):
|
||||||
"""Abstract base class used to aggregate batch-level outputs of a loop.
|
"""Abstract base class used to aggregate batch-level outputs of a loop.
|
||||||
@ -156,8 +169,7 @@ class ConcatAggregator(Aggregator):
|
|||||||
use_steps=True, num_samples=None, steps=None, batch_size=batch_size)
|
use_steps=True, num_samples=None, steps=None, batch_size=batch_size)
|
||||||
|
|
||||||
def create(self, batch_element):
|
def create(self, batch_element):
|
||||||
self.composite = composite_tensor_utils.is_composite_or_composite_value(
|
self.composite = is_composite_or_composite_value(batch_element)
|
||||||
batch_element)
|
|
||||||
|
|
||||||
def aggregate(self, batch_element, batch_start=None, batch_end=None):
|
def aggregate(self, batch_element, batch_start=None, batch_end=None):
|
||||||
|
|
||||||
@ -313,12 +325,11 @@ class OutputsAggregator(Aggregator):
|
|||||||
# SparseTensorValue is a named tuple which nest will flatten, so we need
|
# SparseTensorValue is a named tuple which nest will flatten, so we need
|
||||||
# to guard it to properly handle the structure.
|
# to guard it to properly handle the structure.
|
||||||
self._structure = nest.get_traverse_shallow_structure(
|
self._structure = nest.get_traverse_shallow_structure(
|
||||||
lambda x: not composite_tensor_utils.is_composite_or_composite_value(x),
|
lambda x: not is_composite_or_composite_value(x), batch_outs)
|
||||||
batch_outs)
|
|
||||||
batch_outs = nest.flatten_up_to(self._structure, batch_outs)
|
batch_outs = nest.flatten_up_to(self._structure, batch_outs)
|
||||||
|
|
||||||
for batch_element in batch_outs:
|
for batch_element in batch_outs:
|
||||||
if composite_tensor_utils.is_composite_or_composite_value(batch_element):
|
if is_composite_or_composite_value(batch_element):
|
||||||
# If the output is not a ndarray, it will be either a composite tensor
|
# If the output is not a ndarray, it will be either a composite tensor
|
||||||
# or a composite tensor's Value object. In either case, we can't
|
# or a composite tensor's Value object. In either case, we can't
|
||||||
# allocate an array to hold the object - we'll handle it later.
|
# allocate an array to hold the object - we'll handle it later.
|
||||||
@ -399,7 +410,7 @@ def standardize_single_array(x, expected_shape=None):
|
|||||||
if x is None:
|
if x is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if composite_tensor_utils.is_composite_or_composite_value(x):
|
if is_composite_or_composite_value(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
if isinstance(x, int):
|
if isinstance(x, int):
|
||||||
@ -517,7 +528,7 @@ def standardize_input_data(data,
|
|||||||
if not tensorshape:
|
if not tensorshape:
|
||||||
continue
|
continue
|
||||||
data_shape = tuple(tensorshape.as_list())
|
data_shape = tuple(tensorshape.as_list())
|
||||||
elif composite_tensor_utils.is_composite_or_composite_value(data[i]):
|
elif is_composite_or_composite_value(data[i]):
|
||||||
tensorshape = composite_tensor_utils.get_shape(data[i])
|
tensorshape = composite_tensor_utils.get_shape(data[i])
|
||||||
data_shape = tuple(tensorshape.as_list())
|
data_shape = tuple(tensorshape.as_list())
|
||||||
else:
|
else:
|
||||||
@ -610,8 +621,7 @@ def check_array_lengths(inputs, targets, weights=None):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def is_tensor_or_composite_tensor(x):
|
def is_tensor_or_composite_tensor(x):
|
||||||
return tensor_util.is_tensor(
|
return tensor_util.is_tensor(x) or is_composite_or_composite_value(x)
|
||||||
x) or composite_tensor_utils.is_composite_or_composite_value(x)
|
|
||||||
|
|
||||||
def set_of_lengths(x):
|
def set_of_lengths(x):
|
||||||
# Returns a set with the variation between
|
# Returns a set with the variation between
|
||||||
|
@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context
|
|||||||
from tensorflow.python.distribute import parameter_server_strategy
|
from tensorflow.python.distribute import parameter_server_strategy
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import composite_tensor_utils
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -2495,7 +2494,7 @@ class Model(training_lib.Model):
|
|||||||
# users should explicitly add composite tensor inputs to their subclassed
|
# users should explicitly add composite tensor inputs to their subclassed
|
||||||
# models.
|
# models.
|
||||||
for input_tensor in processed_inputs:
|
for input_tensor in processed_inputs:
|
||||||
if composite_tensor_utils.is_composite_or_composite_value(input_tensor):
|
if training_utils_v1.is_composite_or_composite_value(input_tensor):
|
||||||
# TODO(b/132691975): Document subclass-model CT input handling.
|
# TODO(b/132691975): Document subclass-model CT input handling.
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'All SparseTensor and RaggedTensor inputs must be explicitly '
|
'All SparseTensor and RaggedTensor inputs must be explicitly '
|
||||||
|
Loading…
Reference in New Issue
Block a user