Rollforward: Support CompositeTensor in functional If/While

Remove support for composite expansion in While due to IndexedSlices handling.

PiperOrigin-RevId: 333413179
Change-Id: I9a3a1fab34b74f3741307c9ebdd42aac15f13bf7
Gaurav Jain authored on 2020-09-23 18:05:27 -07:00; committed by TensorFlower Gardener
parent c680c3a0b3
commit faf6a15ff4
3 changed files with 52 additions and 21 deletions
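
With this change, functional_ops.If accepts ConcreteFunction branches whose outputs are composite tensors (for example SparseTensor) and re-packs the flat outputs of the underlying _if op via nest.pack_sequence_as(..., expand_composites=True). A minimal usage sketch follows; it mirrors the new testIfWithFunctionComposite test in the diff below, and the branch names, the public tf.function decorator, and the literal values are illustrative only:

    import tensorflow as tf
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import sparse_tensor
    from tensorflow.python.framework import tensor_spec
    from tensorflow.python.ops import functional_ops

    signature = [tensor_spec.TensorSpec([], dtypes.float32)]

    @tf.function(input_signature=signature)
    def then_branch(x):
      # Each branch may now return a CompositeTensor (here a 1-element SparseTensor).
      return sparse_tensor.SparseTensor([[0]], [x + 1], [1])

    @tf.function(input_signature=signature)
    def else_branch(x):
      return sparse_tensor.SparseTensor([[0]], [x - 1], [1])

    # If() flattens the composite outputs for the underlying _if op and then
    # re-packs them, so the caller gets a SparseTensor back rather than a flat
    # list of component tensors.
    result = functional_ops.If(False, [10.],
                               then_branch.get_concrete_function(),
                               else_branch.get_concrete_function())
    assert isinstance(result, sparse_tensor.SparseTensor)  # result.values == [9.0]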


@@ -2422,7 +2422,7 @@ py_library(
 )
 
 py_test(
-    name = "ops/functional_ops_test",
+    name = "functional_ops_test",
     srcs = ["ops/functional_ops_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",


@@ -838,14 +838,28 @@ def If(cond, inputs, then_branch, else_branch, name=None):
     or else_branch(inputs).
   """
   # pylint: disable=protected-access
+  # Handle the Defun case until users have transitioned to tf.function. Note
+  # that composites may need to be re-packed by the caller.
   if isinstance(then_branch, function._DefinedFunction):
     tlist = [_.type for _ in then_branch.definition.signature.output_arg]
-  else:
-    # We assume that `then_branch` is a ConcreteFunction here.
-    tlist = nest.flatten(then_branch.output_dtypes)
-  return gen_functional_ops._if(
-      cond, inputs, tlist, then_branch, else_branch, name=name)
+    return gen_functional_ops._if(
+        cond, inputs, tlist, then_branch, else_branch, name=name)
+
+  # We assume that `then_branch` is a ConcreteFunction here.
+  then_out = then_branch.structured_outputs
+  else_out = else_branch.structured_outputs
+
+  # Ensure then/else are the same type of composites to avoid an invalid call
+  # to pack_sequence_as later on.
+  nest.assert_same_structure(then_out, else_out, expand_composites=True)
+
+  tlist = nest.flatten(then_branch.output_dtypes)
+  ret = gen_functional_ops._if(
+      cond, inputs, tlist, then_branch, else_branch, name=name)
+
+  # Re-pack the outputs to restore any CompositeTensors
+  return nest.pack_sequence_as(then_out, ret, expand_composites=True)
 
 
 def Gradient(inputs, f, name=None):
   r"""Computes the gradient function for function f via backpropagation.


@@ -21,26 +21,26 @@ from __future__ import print_function
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.platform import test
 
 
 class FunctionalOpsTest(test.TestCase):
 
-  @test_util.deprecated_graph_mode_only
   def testIfWithDefun(self):
     # Defun should only be used in graph mode
-    @function.Defun(dtypes.float32)
-    def Then(x):
-      return x + 1
+    with ops.Graph().as_default():
+      @function.Defun(dtypes.float32)
+      def Then(x):
+        return x + 1
 
-    @function.Defun(dtypes.float32)
-    def Else(x):
-      return x - 1
+      @function.Defun(dtypes.float32)
+      def Else(x):
+        return x - 1
 
-    with self.cached_session():
       inputs = [10.]
       result = self.evaluate(functional_ops.If(False, inputs, Then, Else))
       self.assertEqual([9.0], result)
@@ -57,12 +57,29 @@ class FunctionalOpsTest(test.TestCase):
     def Else(x):
       return x - 1
 
-    with self.cached_session():
-      inputs = [10.]
-      result = self.evaluate(
-          functional_ops.If(False, inputs, Then.get_concrete_function(),
-                            Else.get_concrete_function()))
-      self.assertEqual([9.0], result)
+    inputs = [10.]
+    then_cf = Then.get_concrete_function()
+    else_cf = Else.get_concrete_function()
+    result = self.evaluate(functional_ops.If(False, inputs, then_cf, else_cf))
+    self.assertEqual([9.0], result)
+
+  def testIfWithFunctionComposite(self):
+    signature = [tensor_spec.TensorSpec([], dtypes.float32)]
+
+    @def_function.function(input_signature=signature)
+    def Then(x):
+      return sparse_tensor.SparseTensor([[0]], [x + 1], [1])
+
+    @def_function.function(input_signature=signature)
+    def Else(x):
+      return sparse_tensor.SparseTensor([[0]], [x - 1], [1])
+
+    inputs = [10.]
+    then_cf = Then.get_concrete_function()
+    else_cf = Else.get_concrete_function()
+    result = functional_ops.If(False, inputs, then_cf, else_cf)
+    self.assertIsInstance(result, sparse_tensor.SparseTensor)
+    self.assertAllEqual([9.0], result.values)
 
 
 if __name__ == '__main__':