Make tf.If work with ConcreteFunction.
PiperOrigin-RevId: 313860966 Change-Id: I1fccdaf06802511a7020a4045751cdd6b6821687
This commit is contained in:
parent
f90c649e28
commit
85396efcd3
tensorflow/python
@ -2214,6 +2214,21 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ops/functional_ops_test",
|
||||
srcs = ["ops/functional_ops_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":client_testlib",
|
||||
":dtypes",
|
||||
":function",
|
||||
":functional_ops",
|
||||
":tensor_spec",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "function_test",
|
||||
size = "medium",
|
||||
|
@ -838,12 +838,13 @@ def If(cond, inputs, then_branch, else_branch, name=None):
|
||||
or else_branch(inputs).
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(then_branch, function._DefinedFunction):
|
||||
tlist = [_.type for _ in then_branch.definition.signature.output_arg]
|
||||
else:
|
||||
# We assume that `then_branch` is a ConcreteFunction here.
|
||||
tlist = nest.flatten(then_branch.output_dtypes)
|
||||
return gen_functional_ops._if(
|
||||
cond,
|
||||
inputs, [_.type for _ in then_branch.definition.signature.output_arg],
|
||||
then_branch,
|
||||
else_branch,
|
||||
name=name)
|
||||
cond, inputs, tlist, then_branch, else_branch, name=name)
|
||||
|
||||
|
||||
def Gradient(inputs, f, name=None):
|
||||
|
69
tensorflow/python/ops/functional_ops_test.py
Normal file
69
tensorflow/python/ops/functional_ops_test.py
Normal file
@ -0,0 +1,69 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for functional operations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FunctionalOpsTest(test.TestCase):
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testIfWithDefun(self):
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def Then(x):
|
||||
return x + 1
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def Else(x):
|
||||
return x - 1
|
||||
|
||||
with self.cached_session():
|
||||
inputs = [10.]
|
||||
result = self.evaluate(functional_ops.If(False, inputs, Then, Else))
|
||||
self.assertEqual([9.0], result)
|
||||
|
||||
def testIfWithFunction(self):
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec((), dtypes.float32)])
|
||||
def Then(x):
|
||||
return x + 1
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec((), dtypes.float32)])
|
||||
def Else(x):
|
||||
return x - 1
|
||||
|
||||
with self.cached_session():
|
||||
inputs = [10.]
|
||||
result = self.evaluate(
|
||||
functional_ops.If(False, inputs, Then.get_concrete_function(),
|
||||
Else.get_concrete_function()))
|
||||
self.assertEqual([9.0], result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user