Use a smart constant instead tf.constant -- one that matches the condition value, in the sense that it will either be a Tensor or a Python value.
This change also required adjusting the way side_effect_guards works, which now will gate only the live symbols. A third change adds the live symbols to expression nodes. Motivation: pure-side-effects control flow statements are being augmented with a dummy return value (they just return a single scalar). This provides a return value that can be used as control dependency by side_effect_guards. Such dependency necessarily needs to be a tensor, but we don't want to return a tensor unless the if statement is staged to a cond. This CL makes that return value consistent with the statement itself. PiperOrigin-RevId: 220506051
This commit is contained in:
parent
5b373961c6
commit
81f1d6751b
tensorflow/python/autograph
converters
core
impl
lang
pyct
@ -160,6 +160,10 @@ class ControlFlowTransformer(converter.Base):
|
||||
node_body = ast_util.rename_symbols(node.body, alias_body_map)
|
||||
node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
|
||||
|
||||
cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
|
||||
body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
|
||||
orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
|
||||
|
||||
returned_from_cond = tuple(returned_from_cond)
|
||||
if returned_from_cond:
|
||||
if len(returned_from_cond) == 1:
|
||||
@ -181,13 +185,14 @@ class ControlFlowTransformer(converter.Base):
|
||||
# actually has some return value as well.
|
||||
cond_results = None
|
||||
# TODO(mdan): This doesn't belong here; it's specific to the operator.
|
||||
returned_from_body = (templates.replace_as_expression('tf.constant(1)'),)
|
||||
returned_from_orelse = (
|
||||
templates.replace_as_expression('tf.constant(1)'),)
|
||||
|
||||
body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
|
||||
orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
|
||||
returned_from_body = (templates.replace_as_expression(
|
||||
'ag__.match_staging_level(1, cond_var_name)',
|
||||
cond_var_name=cond_var_name),)
|
||||
returned_from_orelse = (templates.replace_as_expression(
|
||||
'ag__.match_staging_level(1, cond_var_name)',
|
||||
cond_var_name=cond_var_name),)
|
||||
|
||||
cond_assign = self.create_assignment(cond_var_name, node.test)
|
||||
body_def = self._create_cond_branch(
|
||||
body_name,
|
||||
aliased_orig_names=aliased_body_orig_names,
|
||||
@ -200,10 +205,10 @@ class ControlFlowTransformer(converter.Base):
|
||||
aliased_new_names=aliased_orelse_new_names,
|
||||
body=node_orelse,
|
||||
returns=returned_from_orelse)
|
||||
cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
|
||||
cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
|
||||
orelse_name)
|
||||
|
||||
return body_def + orelse_def + cond_expr
|
||||
return cond_assign + body_def + orelse_def + cond_expr
|
||||
|
||||
def _get_loop_state(self, node):
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
|
@ -122,11 +122,12 @@ class SideEffectGuardTransformer(converter.Base):
|
||||
# possible, gate all remaining statements (and that may fail too, see
|
||||
# _visit_and_reindent.
|
||||
args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
|
||||
live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
|
||||
# NOTE: We can't guard object attributes because they may not be writable.
|
||||
# In addition, avoid renaming well-known names.
|
||||
# TODO(mdan): Move these names into config.
|
||||
unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
|
||||
guarded_args = tuple(s for s in args_scope.read
|
||||
unguarded_names = (qual_names.QN('self'), qual_names.QN('ag__'))
|
||||
guarded_args = tuple(s for s in live_out
|
||||
if not s.is_composite() and s not in unguarded_names)
|
||||
|
||||
# TODO(mdan): Include all arguments which depended on guarded_args too.
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python.autograph.core import config
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import errors
|
||||
from tensorflow.python.autograph.core import function_wrapping
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.pyct import compiler
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
@ -103,6 +104,7 @@ class TestCase(test.TestCase):
|
||||
fake_ag = self.make_fake_mod('fake_ag', converted_call,
|
||||
converter.ConversionOptions)
|
||||
fake_ag.__dict__.update(operators.__dict__)
|
||||
fake_ag.__dict__.update(special_functions.__dict__)
|
||||
fake_ag.__dict__['utils'] = utils
|
||||
fake_ag.__dict__['rewrite_graph_construction_error'] = (
|
||||
errors.rewrite_graph_construction_error)
|
||||
|
@ -45,6 +45,7 @@ from tensorflow.python.autograph.core import config
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import errors
|
||||
from tensorflow.python.autograph.core import function_wrapping
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import compiler
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
@ -272,6 +273,7 @@ def _add_self_references(namespace, autograph_module):
|
||||
# TODO(mdan): Add safeguards against name clashes.
|
||||
# We don't want to create a submodule because we want the operators to be
|
||||
# accessible as ag__.<operator>
|
||||
ag_internal.__dict__.update(special_functions.__dict__)
|
||||
ag_internal.__dict__.update(operators.__dict__)
|
||||
|
||||
_add_reserved_symbol(namespace, 'ag__', ag_internal)
|
||||
|
@ -24,6 +24,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.operators import data_structures
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import tensor_util
|
||||
|
||||
|
||||
@ -46,6 +47,13 @@ def _validate_list_constructor(elements, element_dtype, element_shape):
|
||||
' allowed'.format(type(elements)))
|
||||
|
||||
|
||||
def match_staging_level(value, like_value):
|
||||
"""Casts a value to be staged at the same level as another."""
|
||||
if tensor_util.is_tensor(like_value):
|
||||
return constant_op.constant(value)
|
||||
return value
|
||||
|
||||
|
||||
def tensor_list(elements,
|
||||
element_dtype=None,
|
||||
element_shape=None,
|
||||
|
@ -30,26 +30,35 @@ from tensorflow.python.platform import test
|
||||
|
||||
class SpecialFunctionsTest(test.TestCase):
|
||||
|
||||
def test_match_staging_level(self):
|
||||
some_tensor = constant_op.constant(0)
|
||||
tensor_one = special_functions.match_staging_level(1, some_tensor)
|
||||
python_one = special_functions.match_staging_level(1, 1)
|
||||
with self.cached_session() as sess:
|
||||
self.assertTrue(tensor_util.is_tensor(tensor_one))
|
||||
self.assertAllEqual(sess.run(tensor_one), 1)
|
||||
self.assertEqual(python_one, 1)
|
||||
|
||||
def test_tensor_list_empty_list(self):
|
||||
l = special_functions.tensor_list([],
|
||||
element_dtype=dtypes.int32,
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.test_session() as sess:
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
|
||||
l = special_functions.tensor_list((),
|
||||
element_dtype=dtypes.int32,
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.test_session() as sess:
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
|
||||
def test_tensor_list_tensor(self):
|
||||
l = special_functions.tensor_list(
|
||||
constant_op.constant([], dtype=dtypes.int32))
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.test_session() as sess:
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
|
||||
def test_tensor_list_unsupported_initializer(self):
|
||||
@ -66,7 +75,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
|
||||
l = special_functions.tensor_list(elements)
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.test_session() as sess:
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_tensor_list_array_from_elements(self):
|
||||
@ -74,7 +83,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
|
||||
l = special_functions.tensor_list(elements, use_tensor_array=True)
|
||||
sl = l.stack()
|
||||
with self.test_session() as sess:
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_stack(self):
|
||||
|
@ -198,6 +198,13 @@ class Annotator(transformer.Base):
|
||||
node = self._block_statement_live_out(node)
|
||||
return self._block_statement_live_in(node, node.test)
|
||||
|
||||
def visit_Expr(self, node):
|
||||
node = self.generic_visit(node)
|
||||
cfg_node = self.current_analyzer.graph.index[node]
|
||||
anno.setanno(node, anno.Static.LIVE_VARS_OUT,
|
||||
frozenset(self.current_analyzer.out[cfg_node]))
|
||||
return node
|
||||
|
||||
|
||||
def resolve(node, source_info, graphs):
|
||||
"""Resolves the live symbols at the exit of control flow statements.
|
||||
|
@ -26,6 +26,7 @@ import six
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import compiler
|
||||
from tensorflow.python.autograph.pyct import pretty_printer
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
|
||||
|
||||
class AutographParseError(SyntaxError):
|
||||
@ -280,6 +281,12 @@ class Base(gast.NodeTransformer):
|
||||
print(pretty_printer.fmt(node))
|
||||
return node
|
||||
|
||||
def create_assignment(self, target, expression):
|
||||
template = """
|
||||
target = expression
|
||||
"""
|
||||
return templates.replace(template, target=target, expression=expression)
|
||||
|
||||
def visit_block(self, nodes, before_visit=None, after_visit=None):
|
||||
"""A more powerful version of generic_visit for statement blocks.
|
||||
|
||||
@ -316,13 +323,14 @@ class Base(gast.NodeTransformer):
|
||||
Args:
|
||||
nodes: enumerable of AST node objects. If None, the function returns None.
|
||||
before_visit: optional callable that is called before visiting each item
|
||||
in nodes
|
||||
after_visit: optional callable that takes in an AST node and
|
||||
returns a tuple (new_node, new_destination). It is called after
|
||||
visiting each item in nodes. Is used in the same was as the
|
||||
in nodes
|
||||
after_visit: optional callable that takes in an AST node and returns a
|
||||
tuple (new_node, new_destination). It is called after visiting each item
|
||||
in nodes. Is used in the same was as the
|
||||
visit_* methods: new_node will replace the node; if not None,
|
||||
new_destination must be a list, and subsequent nodes will be placed
|
||||
in this list instead of the list returned by visit_block.
|
||||
new_destination must be a list, and subsequent nodes will be placed
|
||||
in this list instead of the list returned by visit_block.
|
||||
|
||||
Returns:
|
||||
A list of AST node objects containing the transformed items fron nodes,
|
||||
except those nodes that have been relocated using after_visit.
|
||||
|
Loading…
Reference in New Issue
Block a user