From b8fcda3cd1fe0a69b0957b75dbd7738c598d1749 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan <kaftan@google.com> Date: Mon, 26 Oct 2020 14:38:11 -0700 Subject: [PATCH] Fork is_composite_or_commposite_value into Keras to split a dependency on private symbols PiperOrigin-RevId: 339120462 Change-Id: I8c48668daf0daf330b4d34bc58b4a20f1a9de67c --- .../python/keras/engine/training_utils_v1.py | 28 +++++++++++++------ tensorflow/python/keras/engine/training_v1.py | 3 +- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/keras/engine/training_utils_v1.py b/tensorflow/python/keras/engine/training_utils_v1.py index bc83b67fdea..fff7fd1fea5 100644 --- a/tensorflow/python/keras/engine/training_utils_v1.py +++ b/tensorflow/python/keras/engine/training_utils_v1.py @@ -36,11 +36,13 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import composite_tensor_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K @@ -55,11 +57,22 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest from tensorflow.python.util.compat import collections_abc +def is_composite_or_composite_value(tensor): + """Returns true if 'tensor' is a CompositeTensor or a CT Value object.""" + # TODO(b/125094323): This should be isinstance(CompositeTensor) or + # isinstance(CompositeTensorValue) once we support that. + return isinstance( + tensor, + (composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue, + ragged_tensor_value.RaggedTensorValue)) + + @six.add_metaclass(abc.ABCMeta) class Aggregator(object): """Abstract base class used to aggregate batch-level outputs of a loop. @@ -156,8 +169,7 @@ class ConcatAggregator(Aggregator): use_steps=True, num_samples=None, steps=None, batch_size=batch_size) def create(self, batch_element): - self.composite = composite_tensor_utils.is_composite_or_composite_value( - batch_element) + self.composite = is_composite_or_composite_value(batch_element) def aggregate(self, batch_element, batch_start=None, batch_end=None): @@ -313,12 +325,11 @@ class OutputsAggregator(Aggregator): # SparseTensorValue is a named tuple which nest will flatten, so we need # to guard it to properly handle the structure. self._structure = nest.get_traverse_shallow_structure( - lambda x: not composite_tensor_utils.is_composite_or_composite_value(x), - batch_outs) + lambda x: not is_composite_or_composite_value(x), batch_outs) batch_outs = nest.flatten_up_to(self._structure, batch_outs) for batch_element in batch_outs: - if composite_tensor_utils.is_composite_or_composite_value(batch_element): + if is_composite_or_composite_value(batch_element): # If the output is not a ndarray, it will be either a composite tensor # or a composite tensor's Value object. In either case, we can't # allocate an array to hold the object - we'll handle it later. @@ -399,7 +410,7 @@ def standardize_single_array(x, expected_shape=None): if x is None: return None - if composite_tensor_utils.is_composite_or_composite_value(x): + if is_composite_or_composite_value(x): return x if isinstance(x, int): @@ -517,7 +528,7 @@ def standardize_input_data(data, if not tensorshape: continue data_shape = tuple(tensorshape.as_list()) - elif composite_tensor_utils.is_composite_or_composite_value(data[i]): + elif is_composite_or_composite_value(data[i]): tensorshape = composite_tensor_utils.get_shape(data[i]) data_shape = tuple(tensorshape.as_list()) else: @@ -610,8 +621,7 @@ def check_array_lengths(inputs, targets, weights=None): """ def is_tensor_or_composite_tensor(x): - return tensor_util.is_tensor( - x) or composite_tensor_utils.is_composite_or_composite_value(x) + return tensor_util.is_tensor(x) or is_composite_or_composite_value(x) def set_of_lengths(x): # Returns a set with the variation between diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 2cbf24bb9ce..5df44699f73 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import parameter_server_strategy from tensorflow.python.eager import context from tensorflow.python.eager import def_function -from tensorflow.python.framework import composite_tensor_utils from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -2495,7 +2494,7 @@ class Model(training_lib.Model): # users should explicitly add composite tensor inputs to their subclassed # models. for input_tensor in processed_inputs: - if composite_tensor_utils.is_composite_or_composite_value(input_tensor): + if training_utils_v1.is_composite_or_composite_value(input_tensor): # TODO(b/132691975): Document subclass-model CT input handling. raise ValueError( 'All SparseTensor and RaggedTensor inputs must be explicitly '