From faf6a15ff47efb8308166dcd0bfd2be206f1c4ca Mon Sep 17 00:00:00 2001 From: Gaurav Jain <gjn@google.com> Date: Wed, 23 Sep 2020 18:05:27 -0700 Subject: [PATCH] Rollforward: Support CompositeTensor in functional If/While Remove support for composite expansion in While due to IndexedSlices handling. PiperOrigin-RevId: 333413179 Change-Id: I9a3a1fab34b74f3741307c9ebdd42aac15f13bf7 --- tensorflow/python/BUILD | 2 +- tensorflow/python/ops/functional_ops.py | 22 +++++++-- tensorflow/python/ops/functional_ops_test.py | 49 +++++++++++++------- 3 files changed, 52 insertions(+), 21 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index adcb1138707..81bf5d95634 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2422,7 +2422,7 @@ py_library( ) py_test( - name = "ops/functional_ops_test", + name = "functional_ops_test", srcs = ["ops/functional_ops_test.py"], python_version = "PY3", srcs_version = "PY2AND3", diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 6e285d6681d..b51d1baa6c0 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -838,14 +838,28 @@ def If(cond, inputs, then_branch, else_branch, name=None): or else_branch(inputs). """ # pylint: disable=protected-access + # Handle the Defun case until users have transitioned to tf.function. Note + # that composites may need to be re-packed by the caller. if isinstance(then_branch, function._DefinedFunction): tlist = [_.type for _ in then_branch.definition.signature.output_arg] - else: - # We assume that `then_branch` is a ConcreteFunction here. - tlist = nest.flatten(then_branch.output_dtypes) - return gen_functional_ops._if( + return gen_functional_ops._if( + cond, inputs, tlist, then_branch, else_branch, name=name) + + # We assume that `then_branch` is a ConcreteFunction here. + then_out = then_branch.structured_outputs + else_out = else_branch.structured_outputs + + # Ensure then/else are the same type of composites to avoid an invalid call + # to pack_sequence_as later on. + nest.assert_same_structure(then_out, else_out, expand_composites=True) + + tlist = nest.flatten(then_branch.output_dtypes) + ret = gen_functional_ops._if( cond, inputs, tlist, then_branch, else_branch, name=name) + # Re-pack the outputs to restore any CompositeTensors + return nest.pack_sequence_as(then_out, ret, expand_composites=True) + def Gradient(inputs, f, name=None): r"""Computes the gradient function for function f via backpropagation. diff --git a/tensorflow/python/ops/functional_ops_test.py b/tensorflow/python/ops/functional_ops_test.py index 7e3bc631c44..92e97f63ecd 100644 --- a/tensorflow/python/ops/functional_ops_test.py +++ b/tensorflow/python/ops/functional_ops_test.py @@ -21,26 +21,26 @@ from __future__ import print_function from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec -from tensorflow.python.framework import test_util from tensorflow.python.ops import functional_ops from tensorflow.python.platform import test class FunctionalOpsTest(test.TestCase): - @test_util.deprecated_graph_mode_only def testIfWithDefun(self): + # Defun should only be used in graph mode + with ops.Graph().as_default(): + @function.Defun(dtypes.float32) + def Then(x): + return x + 1 - @function.Defun(dtypes.float32) - def Then(x): - return x + 1 + @function.Defun(dtypes.float32) + def Else(x): + return x - 1 - @function.Defun(dtypes.float32) - def Else(x): - return x - 1 - - with self.cached_session(): inputs = [10.] result = self.evaluate(functional_ops.If(False, inputs, Then, Else)) self.assertEqual([9.0], result) @@ -57,12 +57,29 @@ class FunctionalOpsTest(test.TestCase): def Else(x): return x - 1 - with self.cached_session(): - inputs = [10.] - result = self.evaluate( - functional_ops.If(False, inputs, Then.get_concrete_function(), - Else.get_concrete_function())) - self.assertEqual([9.0], result) + inputs = [10.] + then_cf = Then.get_concrete_function() + else_cf = Else.get_concrete_function() + result = self.evaluate(functional_ops.If(False, inputs, then_cf, else_cf)) + self.assertEqual([9.0], result) + + def testIfWithFunctionComposite(self): + + signature = [tensor_spec.TensorSpec([], dtypes.float32)] + @def_function.function(input_signature=signature) + def Then(x): + return sparse_tensor.SparseTensor([[0]], [x + 1], [1]) + + @def_function.function(input_signature=signature) + def Else(x): + return sparse_tensor.SparseTensor([[0]], [x - 1], [1]) + + inputs = [10.] + then_cf = Then.get_concrete_function() + else_cf = Else.get_concrete_function() + result = functional_ops.If(False, inputs, then_cf, else_cf) + self.assertIsInstance(result, sparse_tensor.SparseTensor) + self.assertAllEqual([9.0], result.values) if __name__ == '__main__':