From faf6a15ff47efb8308166dcd0bfd2be206f1c4ca Mon Sep 17 00:00:00 2001
From: Gaurav Jain <gjn@google.com>
Date: Wed, 23 Sep 2020 18:05:27 -0700
Subject: [PATCH] Rollforward: Support CompositeTensor in functional If/While

Remove support for composite expansion in While due to IndexedSlices handling.

PiperOrigin-RevId: 333413179
Change-Id: I9a3a1fab34b74f3741307c9ebdd42aac15f13bf7
---
 tensorflow/python/BUILD                      |  2 +-
 tensorflow/python/ops/functional_ops.py      | 22 +++++++--
 tensorflow/python/ops/functional_ops_test.py | 49 +++++++++++++-------
 3 files changed, 52 insertions(+), 21 deletions(-)

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index adcb1138707..81bf5d95634 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2422,7 +2422,7 @@ py_library(
 )
 
 py_test(
-    name = "ops/functional_ops_test",
+    name = "functional_ops_test",
     srcs = ["ops/functional_ops_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 6e285d6681d..b51d1baa6c0 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -838,14 +838,28 @@ def If(cond, inputs, then_branch, else_branch, name=None):
     or else_branch(inputs).
   """
   # pylint: disable=protected-access
+  # Handle the Defun case until users have transitioned to tf.function. Note
+  # that composites may need to be re-packed by the caller.
   if isinstance(then_branch, function._DefinedFunction):
     tlist = [_.type for _ in then_branch.definition.signature.output_arg]
-  else:
-    # We assume that `then_branch` is a ConcreteFunction here.
-    tlist = nest.flatten(then_branch.output_dtypes)
-  return gen_functional_ops._if(
+    return gen_functional_ops._if(
+        cond, inputs, tlist, then_branch, else_branch, name=name)
+
+  # We assume that `then_branch` is a ConcreteFunction here.
+  then_out = then_branch.structured_outputs
+  else_out = else_branch.structured_outputs
+
+  # Ensure then/else are the same type of composites to avoid an invalid call
+  # to pack_sequence_as later on.
+  nest.assert_same_structure(then_out, else_out, expand_composites=True)
+
+  tlist = nest.flatten(then_branch.output_dtypes)
+  ret = gen_functional_ops._if(
       cond, inputs, tlist, then_branch, else_branch, name=name)
 
+  # Re-pack the outputs to restore any CompositeTensors
+  return nest.pack_sequence_as(then_out, ret, expand_composites=True)
+
 
 def Gradient(inputs, f, name=None):
   r"""Computes the gradient function for function f via backpropagation.
diff --git a/tensorflow/python/ops/functional_ops_test.py b/tensorflow/python/ops/functional_ops_test.py
index 7e3bc631c44..92e97f63ecd 100644
--- a/tensorflow/python/ops/functional_ops_test.py
+++ b/tensorflow/python/ops/functional_ops_test.py
@@ -21,26 +21,26 @@ from __future__ import print_function
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.platform import test
 
 
 class FunctionalOpsTest(test.TestCase):
 
-  @test_util.deprecated_graph_mode_only
   def testIfWithDefun(self):
+    # Defun should only be used in graph mode
+    with ops.Graph().as_default():
+      @function.Defun(dtypes.float32)
+      def Then(x):
+        return x + 1
 
-    @function.Defun(dtypes.float32)
-    def Then(x):
-      return x + 1
+      @function.Defun(dtypes.float32)
+      def Else(x):
+        return x - 1
 
-    @function.Defun(dtypes.float32)
-    def Else(x):
-      return x - 1
-
-    with self.cached_session():
       inputs = [10.]
       result = self.evaluate(functional_ops.If(False, inputs, Then, Else))
       self.assertEqual([9.0], result)
@@ -57,12 +57,29 @@ class FunctionalOpsTest(test.TestCase):
     def Else(x):
       return x - 1
 
-    with self.cached_session():
-      inputs = [10.]
-      result = self.evaluate(
-          functional_ops.If(False, inputs, Then.get_concrete_function(),
-                            Else.get_concrete_function()))
-      self.assertEqual([9.0], result)
+    inputs = [10.]
+    then_cf = Then.get_concrete_function()
+    else_cf = Else.get_concrete_function()
+    result = self.evaluate(functional_ops.If(False, inputs, then_cf, else_cf))
+    self.assertEqual([9.0], result)
+
+  def testIfWithFunctionComposite(self):
+
+    signature = [tensor_spec.TensorSpec([], dtypes.float32)]
+    @def_function.function(input_signature=signature)
+    def Then(x):
+      return sparse_tensor.SparseTensor([[0]], [x + 1], [1])
+
+    @def_function.function(input_signature=signature)
+    def Else(x):
+      return sparse_tensor.SparseTensor([[0]], [x - 1], [1])
+
+    inputs = [10.]
+    then_cf = Then.get_concrete_function()
+    else_cf = Else.get_concrete_function()
+    result = functional_ops.If(False, inputs, then_cf, else_cf)
+    self.assertIsInstance(result, sparse_tensor.SparseTensor)
+    self.assertAllEqual([9.0], result.values)
 
 
 if __name__ == '__main__':