Port the conditional control flow conversion to the new operators.
PiperOrigin-RevId: 216548561
This commit is contained in:
parent
9fe6fe02a1
commit
e09ddb4290
tensorflow/python/autograph
converters
operators
utils
@ -49,12 +49,23 @@ class ControlFlowTransformer(converter.Base):
|
||||
|
||||
def _create_cond_branch(self, body_name, aliased_orig_names,
|
||||
aliased_new_names, body, returns):
|
||||
if len(returns) == 1:
|
||||
template = """
|
||||
return retval
|
||||
"""
|
||||
return_stmt = templates.replace(template, retval=returns[0])
|
||||
else:
|
||||
template = """
|
||||
return (retvals,)
|
||||
"""
|
||||
return_stmt = templates.replace(template, retvals=returns)
|
||||
|
||||
if aliased_orig_names:
|
||||
template = """
|
||||
def body_name():
|
||||
aliased_new_names, = aliased_orig_names,
|
||||
body
|
||||
return (returns,)
|
||||
return_stmt
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
@ -62,20 +73,20 @@ class ControlFlowTransformer(converter.Base):
|
||||
body=body,
|
||||
aliased_orig_names=aliased_orig_names,
|
||||
aliased_new_names=aliased_new_names,
|
||||
returns=returns)
|
||||
return_stmt=return_stmt)
|
||||
else:
|
||||
template = """
|
||||
def body_name():
|
||||
body
|
||||
return (returns,)
|
||||
return_stmt
|
||||
"""
|
||||
return templates.replace(
|
||||
template, body_name=body_name, body=body, returns=returns)
|
||||
template, body_name=body_name, body=body, return_stmt=return_stmt)
|
||||
|
||||
def _create_cond_expr(self, results, test, body_name, orelse_name):
|
||||
if results is not None:
|
||||
template = """
|
||||
results = ag__.utils.run_cond(test, body_name, orelse_name)
|
||||
results = ag__.if_stmt(test, body_name, orelse_name)
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
@ -85,7 +96,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
orelse_name=orelse_name)
|
||||
else:
|
||||
template = """
|
||||
ag__.utils.run_cond(test, body_name, orelse_name)
|
||||
ag__.if_stmt(test, body_name, orelse_name)
|
||||
"""
|
||||
return templates.replace(
|
||||
template, test=test, body_name=body_name, orelse_name=orelse_name)
|
||||
@ -111,7 +122,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
elif s.is_composite():
|
||||
# Special treatment for compound objects: if any of their owner entities
|
||||
# are live, then they are outputs as well.
|
||||
if any(owner in live_out for owner in s.owner_set):
|
||||
if live_out & s.owner_set:
|
||||
returned_from_cond.add(s)
|
||||
|
||||
need_alias_in_body = body_scope.modified & defined_in
|
||||
@ -152,7 +163,6 @@ class ControlFlowTransformer(converter.Base):
|
||||
returned_from_cond = tuple(returned_from_cond)
|
||||
if returned_from_cond:
|
||||
if len(returned_from_cond) == 1:
|
||||
# TODO(mdan): Move this quirk into the operator implementation.
|
||||
cond_results = returned_from_cond[0]
|
||||
else:
|
||||
cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)
|
||||
@ -171,8 +181,9 @@ class ControlFlowTransformer(converter.Base):
|
||||
# actually has some return value as well.
|
||||
cond_results = None
|
||||
# TODO(mdan): This doesn't belong here; it's specific to the operator.
|
||||
returned_from_body = templates.replace_as_expression('tf.constant(1)')
|
||||
returned_from_orelse = templates.replace_as_expression('tf.constant(1)')
|
||||
returned_from_body = (templates.replace_as_expression('tf.constant(1)'),)
|
||||
returned_from_orelse = (
|
||||
templates.replace_as_expression('tf.constant(1)'),)
|
||||
|
||||
body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
|
||||
orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
|
||||
|
@ -80,20 +80,34 @@ class WhileLoopTest(test.TestCase):
|
||||
|
||||
class IfStmtTest(test.TestCase):
|
||||
|
||||
def test_tensor(self):
|
||||
def test_if_stmt(cond):
|
||||
return control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: 1,
|
||||
orelse=lambda: -1)
|
||||
def single_return_if_stmt(self, cond):
|
||||
return control_flow.if_stmt(cond=cond, body=lambda: 1, orelse=lambda: -1)
|
||||
|
||||
def multi_return_if_stmt(self, cond):
|
||||
return control_flow.if_stmt(
|
||||
cond=cond, body=lambda: (1, 2), orelse=lambda: (-1, -2))
|
||||
|
||||
def test_tensor(self):
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True))))
|
||||
self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False))))
|
||||
t = self.single_return_if_stmt(constant_op.constant(True))
|
||||
self.assertEqual(1, sess.run(t))
|
||||
t = self.single_return_if_stmt(constant_op.constant(False))
|
||||
self.assertEqual(-1, sess.run(t))
|
||||
|
||||
def test_python(self):
|
||||
self.assertEqual(1, control_flow.if_stmt(True, lambda: 1, lambda: -1))
|
||||
self.assertEqual(-1, control_flow.if_stmt(False, lambda: 1, lambda: -1))
|
||||
self.assertEqual(1, self.single_return_if_stmt(True))
|
||||
self.assertEqual(-1, self.single_return_if_stmt(False))
|
||||
|
||||
def test_tensor_multiple_returns(self):
|
||||
with self.cached_session() as sess:
|
||||
t = self.multi_return_if_stmt(constant_op.constant(True))
|
||||
self.assertAllEqual([1, 2], sess.run(t))
|
||||
t = self.multi_return_if_stmt(constant_op.constant(False))
|
||||
self.assertAllEqual([-1, -2], sess.run(t))
|
||||
|
||||
def test_python_multiple_returns(self):
|
||||
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
|
||||
self.assertEqual((-1, -2), self.multi_return_if_stmt(False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -22,7 +22,6 @@ py_library(
|
||||
"__init__.py",
|
||||
"context_managers.py",
|
||||
"misc.py",
|
||||
"multiple_dispatch.py",
|
||||
"py_func.py",
|
||||
"tensor_list.py",
|
||||
"tensors.py",
|
||||
@ -61,16 +60,6 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "multiple_dispatch_test",
|
||||
srcs = ["multiple_dispatch_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "py_func_test",
|
||||
srcs = ["py_func_test.py"],
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns
|
||||
from tensorflow.python.autograph.utils.misc import alias_tensors
|
||||
from tensorflow.python.autograph.utils.multiple_dispatch import run_cond
|
||||
from tensorflow.python.autograph.utils.py_func import wrap_py_func
|
||||
from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
|
||||
from tensorflow.python.autograph.utils.testing import fake_tf
|
||||
|
@ -1,56 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for type-dependent behavior used in autograph-generated code."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.utils.type_check import is_tensor
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
|
||||
|
||||
def run_cond(condition, true_fn, false_fn):
|
||||
"""Type-dependent functional conditional.
|
||||
|
||||
Args:
|
||||
condition: A Tensor or Python bool.
|
||||
true_fn: A Python callable implementing the true branch of the conditional.
|
||||
false_fn: A Python callable implementing the false branch of the
|
||||
conditional.
|
||||
|
||||
Returns:
|
||||
result: The result of calling the appropriate branch. If condition is a
|
||||
Tensor, tf.cond will be used. Otherwise, a standard Python if statement will
|
||||
be ran.
|
||||
"""
|
||||
if is_tensor(condition):
|
||||
return control_flow_ops.cond(condition, true_fn, false_fn)
|
||||
else:
|
||||
return py_cond(condition, true_fn, false_fn)
|
||||
|
||||
|
||||
def py_cond(condition, true_fn, false_fn):
|
||||
"""Functional version of Python's conditional."""
|
||||
if condition:
|
||||
results = true_fn()
|
||||
else:
|
||||
results = false_fn()
|
||||
|
||||
# The contract for the branch functions is to return tuples, but they should
|
||||
# be collapsed to a single element when there is only one output.
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
return results
|
@ -1,46 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for multiple_dispatch."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.utils import multiple_dispatch
|
||||
from tensorflow.python.client.session import Session
|
||||
from tensorflow.python.framework.constant_op import constant
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MultipleDispatchTest(test.TestCase):
|
||||
|
||||
def test_run_cond_python(self):
|
||||
true_fn = lambda: (2,)
|
||||
false_fn = lambda: (3,)
|
||||
self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2)
|
||||
self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3)
|
||||
|
||||
def test_run_cond_tf(self):
|
||||
true_fn = lambda: (constant(2),)
|
||||
false_fn = lambda: (constant(3),)
|
||||
with Session() as sess:
|
||||
out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn)
|
||||
self.assertEqual(sess.run(out), 2)
|
||||
out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn)
|
||||
self.assertEqual(sess.run(out), 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user