Port the conditional control flow conversion to the new operators.

PiperOrigin-RevId: 216548561
This commit is contained in:
Dan Moldovan 2018-10-10 10:05:49 -07:00 committed by TensorFlower Gardener
parent 9fe6fe02a1
commit e09ddb4290
6 changed files with 45 additions and 134 deletions

View File

@ -49,12 +49,23 @@ class ControlFlowTransformer(converter.Base):
def _create_cond_branch(self, body_name, aliased_orig_names,
aliased_new_names, body, returns):
if len(returns) == 1:
template = """
return retval
"""
return_stmt = templates.replace(template, retval=returns[0])
else:
template = """
return (retvals,)
"""
return_stmt = templates.replace(template, retvals=returns)
if aliased_orig_names:
template = """
def body_name():
aliased_new_names, = aliased_orig_names,
body
return (returns,)
return_stmt
"""
return templates.replace(
template,
@ -62,20 +73,20 @@ class ControlFlowTransformer(converter.Base):
body=body,
aliased_orig_names=aliased_orig_names,
aliased_new_names=aliased_new_names,
returns=returns)
return_stmt=return_stmt)
else:
template = """
def body_name():
body
return (returns,)
return_stmt
"""
return templates.replace(
template, body_name=body_name, body=body, returns=returns)
template, body_name=body_name, body=body, return_stmt=return_stmt)
def _create_cond_expr(self, results, test, body_name, orelse_name):
if results is not None:
template = """
results = ag__.utils.run_cond(test, body_name, orelse_name)
results = ag__.if_stmt(test, body_name, orelse_name)
"""
return templates.replace(
template,
@ -85,7 +96,7 @@ class ControlFlowTransformer(converter.Base):
orelse_name=orelse_name)
else:
template = """
ag__.utils.run_cond(test, body_name, orelse_name)
ag__.if_stmt(test, body_name, orelse_name)
"""
return templates.replace(
template, test=test, body_name=body_name, orelse_name=orelse_name)
@ -111,7 +122,7 @@ class ControlFlowTransformer(converter.Base):
elif s.is_composite():
# Special treatment for compound objects: if any of their owner entities
# are live, then they are outputs as well.
if any(owner in live_out for owner in s.owner_set):
if live_out & s.owner_set:
returned_from_cond.add(s)
need_alias_in_body = body_scope.modified & defined_in
@ -152,7 +163,6 @@ class ControlFlowTransformer(converter.Base):
returned_from_cond = tuple(returned_from_cond)
if returned_from_cond:
if len(returned_from_cond) == 1:
# TODO(mdan): Move this quirk into the operator implementation.
cond_results = returned_from_cond[0]
else:
cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)
@ -171,8 +181,9 @@ class ControlFlowTransformer(converter.Base):
# actually has some return value as well.
cond_results = None
# TODO(mdan): This doesn't belong here; it's specific to the operator.
returned_from_body = templates.replace_as_expression('tf.constant(1)')
returned_from_orelse = templates.replace_as_expression('tf.constant(1)')
returned_from_body = (templates.replace_as_expression('tf.constant(1)'),)
returned_from_orelse = (
templates.replace_as_expression('tf.constant(1)'),)
body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

View File

@ -80,20 +80,34 @@ class WhileLoopTest(test.TestCase):
class IfStmtTest(test.TestCase):
def test_tensor(self):
def test_if_stmt(cond):
return control_flow.if_stmt(
cond=cond,
body=lambda: 1,
orelse=lambda: -1)
def single_return_if_stmt(self, cond):
return control_flow.if_stmt(cond=cond, body=lambda: 1, orelse=lambda: -1)
def multi_return_if_stmt(self, cond):
return control_flow.if_stmt(
cond=cond, body=lambda: (1, 2), orelse=lambda: (-1, -2))
def test_tensor(self):
with self.cached_session() as sess:
self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True))))
self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False))))
t = self.single_return_if_stmt(constant_op.constant(True))
self.assertEqual(1, sess.run(t))
t = self.single_return_if_stmt(constant_op.constant(False))
self.assertEqual(-1, sess.run(t))
def test_python(self):
self.assertEqual(1, control_flow.if_stmt(True, lambda: 1, lambda: -1))
self.assertEqual(-1, control_flow.if_stmt(False, lambda: 1, lambda: -1))
self.assertEqual(1, self.single_return_if_stmt(True))
self.assertEqual(-1, self.single_return_if_stmt(False))
def test_tensor_multiple_returns(self):
with self.cached_session() as sess:
t = self.multi_return_if_stmt(constant_op.constant(True))
self.assertAllEqual([1, 2], sess.run(t))
t = self.multi_return_if_stmt(constant_op.constant(False))
self.assertAllEqual([-1, -2], sess.run(t))
def test_python_multiple_returns(self):
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
self.assertEqual((-1, -2), self.multi_return_if_stmt(False))
if __name__ == '__main__':

View File

@ -22,7 +22,6 @@ py_library(
"__init__.py",
"context_managers.py",
"misc.py",
"multiple_dispatch.py",
"py_func.py",
"tensor_list.py",
"tensors.py",
@ -61,16 +60,6 @@ py_test(
],
)
py_test(
name = "multiple_dispatch_test",
srcs = ["multiple_dispatch_test.py"],
srcs_version = "PY2AND3",
deps = [
":utils",
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "py_func_test",
srcs = ["py_func_test.py"],

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns
from tensorflow.python.autograph.utils.misc import alias_tensors
from tensorflow.python.autograph.utils.multiple_dispatch import run_cond
from tensorflow.python.autograph.utils.py_func import wrap_py_func
from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
from tensorflow.python.autograph.utils.testing import fake_tf

View File

@ -1,56 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for type-dependent behavior used in autograph-generated code."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops
def run_cond(condition, true_fn, false_fn):
"""Type-dependent functional conditional.
Args:
condition: A Tensor or Python bool.
true_fn: A Python callable implementing the true branch of the conditional.
false_fn: A Python callable implementing the false branch of the
conditional.
Returns:
result: The result of calling the appropriate branch. If condition is a
Tensor, tf.cond will be used. Otherwise, a standard Python if statement will
be ran.
"""
if is_tensor(condition):
return control_flow_ops.cond(condition, true_fn, false_fn)
else:
return py_cond(condition, true_fn, false_fn)
def py_cond(condition, true_fn, false_fn):
"""Functional version of Python's conditional."""
if condition:
results = true_fn()
else:
results = false_fn()
# The contract for the branch functions is to return tuples, but they should
# be collapsed to a single element when there is only one output.
if len(results) == 1:
return results[0]
return results

View File

@ -1,46 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for multiple_dispatch."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.utils import multiple_dispatch
from tensorflow.python.client.session import Session
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.platform import test
class MultipleDispatchTest(test.TestCase):
def test_run_cond_python(self):
true_fn = lambda: (2,)
false_fn = lambda: (3,)
self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2)
self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3)
def test_run_cond_tf(self):
true_fn = lambda: (constant(2),)
false_fn = lambda: (constant(3),)
with Session() as sess:
out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn)
self.assertEqual(sess.run(out), 2)
out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn)
self.assertEqual(sess.run(out), 3)
if __name__ == '__main__':
test.main()