Use a smart constant instead tf.constant -- one that matches the condition value, in the sense that it will either be a Tensor or a Python value.

This change also required adjusting the way side_effect_guards works, which now will gate only the live symbols. A third change adds the live symbols to expression nodes.

Motivation: pure-side-effects control flow statements are being augmented with a dummy return value (they just return a single scalar). This provides a return value that can be used as control dependency by side_effect_guards. Such dependency necessarily needs to be a tensor, but we don't want to return a tensor unless the if statement is staged to a cond. This CL makes that return value consistent with the statement itself.
PiperOrigin-RevId: 220506051
This commit is contained in:
Dan Moldovan 2018-11-07 12:07:26 -08:00 committed by TensorFlower Gardener
parent 5b373961c6
commit 81f1d6751b
8 changed files with 63 additions and 21 deletions

View File

@ -160,6 +160,10 @@ class ControlFlowTransformer(converter.Base):
node_body = ast_util.rename_symbols(node.body, alias_body_map)
node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
returned_from_cond = tuple(returned_from_cond)
if returned_from_cond:
if len(returned_from_cond) == 1:
@ -181,13 +185,14 @@ class ControlFlowTransformer(converter.Base):
# actually has some return value as well.
cond_results = None
# TODO(mdan): This doesn't belong here; it's specific to the operator.
returned_from_body = (templates.replace_as_expression('tf.constant(1)'),)
returned_from_orelse = (
templates.replace_as_expression('tf.constant(1)'),)
body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
returned_from_body = (templates.replace_as_expression(
'ag__.match_staging_level(1, cond_var_name)',
cond_var_name=cond_var_name),)
returned_from_orelse = (templates.replace_as_expression(
'ag__.match_staging_level(1, cond_var_name)',
cond_var_name=cond_var_name),)
cond_assign = self.create_assignment(cond_var_name, node.test)
body_def = self._create_cond_branch(
body_name,
aliased_orig_names=aliased_body_orig_names,
@ -200,10 +205,10 @@ class ControlFlowTransformer(converter.Base):
aliased_new_names=aliased_orelse_new_names,
body=node_orelse,
returns=returned_from_orelse)
cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
orelse_name)
return body_def + orelse_def + cond_expr
return cond_assign + body_def + orelse_def + cond_expr
def _get_loop_state(self, node):
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

View File

@ -122,11 +122,12 @@ class SideEffectGuardTransformer(converter.Base):
# possible, gate all remaining statements (and that may fail too, see
# _visit_and_reindent.
args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
# NOTE: We can't guard object attributes because they may not be writable.
# In addition, avoid renaming well-known names.
# TODO(mdan): Move these names into config.
unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
guarded_args = tuple(s for s in args_scope.read
unguarded_names = (qual_names.QN('self'), qual_names.QN('ag__'))
guarded_args = tuple(s for s in live_out
if not s.is_composite() and s not in unguarded_names)
# TODO(mdan): Include all arguments which depended on guarded_args too.

View File

@ -30,6 +30,7 @@ from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import errors
from tensorflow.python.autograph.core import function_wrapping
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
@ -103,6 +104,7 @@ class TestCase(test.TestCase):
fake_ag = self.make_fake_mod('fake_ag', converted_call,
converter.ConversionOptions)
fake_ag.__dict__.update(operators.__dict__)
fake_ag.__dict__.update(special_functions.__dict__)
fake_ag.__dict__['utils'] = utils
fake_ag.__dict__['rewrite_graph_construction_error'] = (
errors.rewrite_graph_construction_error)

View File

@ -45,6 +45,7 @@ from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import errors
from tensorflow.python.autograph.core import function_wrapping
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import inspect_utils
@ -272,6 +273,7 @@ def _add_self_references(namespace, autograph_module):
# TODO(mdan): Add safeguards against name clashes.
# We don't want to create a submodule because we want the operators to be
# accessible as ag__.<operator>
ag_internal.__dict__.update(special_functions.__dict__)
ag_internal.__dict__.update(operators.__dict__)
_add_reserved_symbol(namespace, 'ag__', ag_internal)

View File

@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import data_structures
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_util
@ -46,6 +47,13 @@ def _validate_list_constructor(elements, element_dtype, element_shape):
' allowed'.format(type(elements)))
def match_staging_level(value, like_value):
"""Casts a value to be staged at the same level as another."""
if tensor_util.is_tensor(like_value):
return constant_op.constant(value)
return value
def tensor_list(elements,
element_dtype=None,
element_shape=None,

View File

@ -30,26 +30,35 @@ from tensorflow.python.platform import test
class SpecialFunctionsTest(test.TestCase):
def test_match_staging_level(self):
some_tensor = constant_op.constant(0)
tensor_one = special_functions.match_staging_level(1, some_tensor)
python_one = special_functions.match_staging_level(1, 1)
with self.cached_session() as sess:
self.assertTrue(tensor_util.is_tensor(tensor_one))
self.assertAllEqual(sess.run(tensor_one), 1)
self.assertEqual(python_one, 1)
def test_tensor_list_empty_list(self):
l = special_functions.tensor_list([],
element_dtype=dtypes.int32,
element_shape=())
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.test_session() as sess:
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [])
l = special_functions.tensor_list((),
element_dtype=dtypes.int32,
element_shape=())
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.test_session() as sess:
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [])
def test_tensor_list_tensor(self):
l = special_functions.tensor_list(
constant_op.constant([], dtype=dtypes.int32))
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.test_session() as sess:
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [])
def test_tensor_list_unsupported_initializer(self):
@ -66,7 +75,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements)
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.test_session() as sess:
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_tensor_list_array_from_elements(self):
@ -74,7 +83,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements, use_tensor_array=True)
sl = l.stack()
with self.test_session() as sess:
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_stack(self):

View File

@ -198,6 +198,13 @@ class Annotator(transformer.Base):
node = self._block_statement_live_out(node)
return self._block_statement_live_in(node, node.test)
def visit_Expr(self, node):
node = self.generic_visit(node)
cfg_node = self.current_analyzer.graph.index[node]
anno.setanno(node, anno.Static.LIVE_VARS_OUT,
frozenset(self.current_analyzer.out[cfg_node]))
return node
def resolve(node, source_info, graphs):
"""Resolves the live symbols at the exit of control flow statements.

View File

@ -26,6 +26,7 @@ import six
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import templates
class AutographParseError(SyntaxError):
@ -280,6 +281,12 @@ class Base(gast.NodeTransformer):
print(pretty_printer.fmt(node))
return node
def create_assignment(self, target, expression):
template = """
target = expression
"""
return templates.replace(template, target=target, expression=expression)
def visit_block(self, nodes, before_visit=None, after_visit=None):
"""A more powerful version of generic_visit for statement blocks.
@ -316,13 +323,14 @@ class Base(gast.NodeTransformer):
Args:
nodes: enumerable of AST node objects. If None, the function returns None.
before_visit: optional callable that is called before visiting each item
in nodes
after_visit: optional callable that takes in an AST node and
returns a tuple (new_node, new_destination). It is called after
visiting each item in nodes. Is used in the same was as the
in nodes
after_visit: optional callable that takes in an AST node and returns a
tuple (new_node, new_destination). It is called after visiting each item
in nodes. Is used in the same was as the
visit_* methods: new_node will replace the node; if not None,
new_destination must be a list, and subsequent nodes will be placed
in this list instead of the list returned by visit_block.
new_destination must be a list, and subsequent nodes will be placed
in this list instead of the list returned by visit_block.
Returns:
A list of AST node objects containing the transformed items fron nodes,
except those nodes that have been relocated using after_visit.