Rollforward: Support CompositeTensor in functional If/While
Remove support for composite expansion in While due to IndexedSlices handling. PiperOrigin-RevId: 333413179 Change-Id: I9a3a1fab34b74f3741307c9ebdd42aac15f13bf7
This commit is contained in:
parent
c680c3a0b3
commit
faf6a15ff4
@ -2422,7 +2422,7 @@ py_library(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ops/functional_ops_test",
|
||||
name = "functional_ops_test",
|
||||
srcs = ["ops/functional_ops_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
|
@ -838,14 +838,28 @@ def If(cond, inputs, then_branch, else_branch, name=None):
|
||||
or else_branch(inputs).
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
# Handle the Defun case until users have transitioned to tf.function. Note
|
||||
# that composites may need to be re-packed by the caller.
|
||||
if isinstance(then_branch, function._DefinedFunction):
|
||||
tlist = [_.type for _ in then_branch.definition.signature.output_arg]
|
||||
else:
|
||||
# We assume that `then_branch` is a ConcreteFunction here.
|
||||
tlist = nest.flatten(then_branch.output_dtypes)
|
||||
return gen_functional_ops._if(
|
||||
return gen_functional_ops._if(
|
||||
cond, inputs, tlist, then_branch, else_branch, name=name)
|
||||
|
||||
# We assume that `then_branch` is a ConcreteFunction here.
|
||||
then_out = then_branch.structured_outputs
|
||||
else_out = else_branch.structured_outputs
|
||||
|
||||
# Ensure then/else are the same type of composites to avoid an invalid call
|
||||
# to pack_sequence_as later on.
|
||||
nest.assert_same_structure(then_out, else_out, expand_composites=True)
|
||||
|
||||
tlist = nest.flatten(then_branch.output_dtypes)
|
||||
ret = gen_functional_ops._if(
|
||||
cond, inputs, tlist, then_branch, else_branch, name=name)
|
||||
|
||||
# Re-pack the outputs to restore any CompositeTensors
|
||||
return nest.pack_sequence_as(then_out, ret, expand_composites=True)
|
||||
|
||||
|
||||
def Gradient(inputs, f, name=None):
|
||||
r"""Computes the gradient function for function f via backpropagation.
|
||||
|
@ -21,26 +21,26 @@ from __future__ import print_function
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FunctionalOpsTest(test.TestCase):
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testIfWithDefun(self):
|
||||
# Defun should only be used in graph mode
|
||||
with ops.Graph().as_default():
|
||||
@function.Defun(dtypes.float32)
|
||||
def Then(x):
|
||||
return x + 1
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def Then(x):
|
||||
return x + 1
|
||||
@function.Defun(dtypes.float32)
|
||||
def Else(x):
|
||||
return x - 1
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def Else(x):
|
||||
return x - 1
|
||||
|
||||
with self.cached_session():
|
||||
inputs = [10.]
|
||||
result = self.evaluate(functional_ops.If(False, inputs, Then, Else))
|
||||
self.assertEqual([9.0], result)
|
||||
@ -57,12 +57,29 @@ class FunctionalOpsTest(test.TestCase):
|
||||
def Else(x):
|
||||
return x - 1
|
||||
|
||||
with self.cached_session():
|
||||
inputs = [10.]
|
||||
result = self.evaluate(
|
||||
functional_ops.If(False, inputs, Then.get_concrete_function(),
|
||||
Else.get_concrete_function()))
|
||||
self.assertEqual([9.0], result)
|
||||
inputs = [10.]
|
||||
then_cf = Then.get_concrete_function()
|
||||
else_cf = Else.get_concrete_function()
|
||||
result = self.evaluate(functional_ops.If(False, inputs, then_cf, else_cf))
|
||||
self.assertEqual([9.0], result)
|
||||
|
||||
def testIfWithFunctionComposite(self):
|
||||
|
||||
signature = [tensor_spec.TensorSpec([], dtypes.float32)]
|
||||
@def_function.function(input_signature=signature)
|
||||
def Then(x):
|
||||
return sparse_tensor.SparseTensor([[0]], [x + 1], [1])
|
||||
|
||||
@def_function.function(input_signature=signature)
|
||||
def Else(x):
|
||||
return sparse_tensor.SparseTensor([[0]], [x - 1], [1])
|
||||
|
||||
inputs = [10.]
|
||||
then_cf = Then.get_concrete_function()
|
||||
else_cf = Else.get_concrete_function()
|
||||
result = functional_ops.If(False, inputs, then_cf, else_cf)
|
||||
self.assertIsInstance(result, sparse_tensor.SparseTensor)
|
||||
self.assertAllEqual([9.0], result.values)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user