From b8fcda3cd1fe0a69b0957b75dbd7738c598d1749 Mon Sep 17 00:00:00 2001
From: Tomer Kaftan <kaftan@google.com>
Date: Mon, 26 Oct 2020 14:38:11 -0700
Subject: [PATCH] Fork is_composite_or_commposite_value into Keras to split a
 dependency on private symbols

PiperOrigin-RevId: 339120462
Change-Id: I8c48668daf0daf330b4d34bc58b4a20f1a9de67c
---
 .../python/keras/engine/training_utils_v1.py  | 28 +++++++++++++------
 tensorflow/python/keras/engine/training_v1.py |  3 +-
 2 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/tensorflow/python/keras/engine/training_utils_v1.py b/tensorflow/python/keras/engine/training_utils_v1.py
index bc83b67fdea..fff7fd1fea5 100644
--- a/tensorflow/python/keras/engine/training_utils_v1.py
+++ b/tensorflow/python/keras/engine/training_utils_v1.py
@@ -36,11 +36,13 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.eager import context
+from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import composite_tensor_utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend as K
@@ -55,11 +57,22 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.ops.ragged import ragged_tensor_value
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 from tensorflow.python.util.compat import collections_abc
 
 
+def is_composite_or_composite_value(tensor):
+  """Returns true if 'tensor' is a CompositeTensor or a CT Value object."""
+  # TODO(b/125094323): This should be isinstance(CompositeTensor) or
+  # isinstance(CompositeTensorValue) once we support that.
+  return isinstance(
+      tensor,
+      (composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue,
+       ragged_tensor_value.RaggedTensorValue))
+
+
 @six.add_metaclass(abc.ABCMeta)
 class Aggregator(object):
   """Abstract base class used to aggregate batch-level outputs of a loop.
@@ -156,8 +169,7 @@ class ConcatAggregator(Aggregator):
         use_steps=True, num_samples=None, steps=None, batch_size=batch_size)
 
   def create(self, batch_element):
-    self.composite = composite_tensor_utils.is_composite_or_composite_value(
-        batch_element)
+    self.composite = is_composite_or_composite_value(batch_element)
 
   def aggregate(self, batch_element, batch_start=None, batch_end=None):
 
@@ -313,12 +325,11 @@ class OutputsAggregator(Aggregator):
     # SparseTensorValue is a named tuple which nest will flatten, so we need
     # to guard it to properly handle the structure.
     self._structure = nest.get_traverse_shallow_structure(
-        lambda x: not composite_tensor_utils.is_composite_or_composite_value(x),
-        batch_outs)
+        lambda x: not is_composite_or_composite_value(x), batch_outs)
     batch_outs = nest.flatten_up_to(self._structure, batch_outs)
 
     for batch_element in batch_outs:
-      if composite_tensor_utils.is_composite_or_composite_value(batch_element):
+      if is_composite_or_composite_value(batch_element):
         # If the output is not a ndarray, it will be either a composite tensor
         # or a composite tensor's Value object. In either case, we can't
         # allocate an array to hold the object - we'll handle it later.
@@ -399,7 +410,7 @@ def standardize_single_array(x, expected_shape=None):
   if x is None:
     return None
 
-  if composite_tensor_utils.is_composite_or_composite_value(x):
+  if is_composite_or_composite_value(x):
     return x
 
   if isinstance(x, int):
@@ -517,7 +528,7 @@ def standardize_input_data(data,
           if not tensorshape:
             continue
           data_shape = tuple(tensorshape.as_list())
-        elif composite_tensor_utils.is_composite_or_composite_value(data[i]):
+        elif is_composite_or_composite_value(data[i]):
           tensorshape = composite_tensor_utils.get_shape(data[i])
           data_shape = tuple(tensorshape.as_list())
         else:
@@ -610,8 +621,7 @@ def check_array_lengths(inputs, targets, weights=None):
   """
 
   def is_tensor_or_composite_tensor(x):
-    return tensor_util.is_tensor(
-        x) or composite_tensor_utils.is_composite_or_composite_value(x)
+    return tensor_util.is_tensor(x) or is_composite_or_composite_value(x)
 
   def set_of_lengths(x):
     # Returns a set with the variation between
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index 2cbf24bb9ce..5df44699f73 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.framework import composite_tensor_utils
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
@@ -2495,7 +2494,7 @@ class Model(training_lib.Model):
     # users should explicitly add composite tensor inputs to their subclassed
     # models.
     for input_tensor in processed_inputs:
-      if composite_tensor_utils.is_composite_or_composite_value(input_tensor):
+      if training_utils_v1.is_composite_or_composite_value(input_tensor):
         # TODO(b/132691975): Document subclass-model CT input handling.
         raise ValueError(
             'All SparseTensor and RaggedTensor inputs must be explicitly '