Extract the iterated expression of a for loop into a variable to avoid repeated staging.
PiperOrigin-RevId: 188316160
This commit is contained in:
parent
4ac1fee7f1
commit
51fd9d70b8
@ -51,7 +51,7 @@ class BuiltinFunctionTransformer(transformer.Base):
|
||||
def visit_Call(self, node):
|
||||
self.generic_visit(node)
|
||||
# TODO(mdan): This won't work if the function was hidden.
|
||||
if isinstance(node.func, gast.Name) and node.func.id in ('len',):
|
||||
if isinstance(node.func, gast.Name) and node.func.id in ('len', 'range'):
|
||||
return self._convert_builtin(node)
|
||||
# Print needs to be handled separately because it can be read as statement.
|
||||
if isinstance(node.func, gast.Name) and node.func.id == 'print':
|
||||
|
@ -37,14 +37,18 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
|
||||
def visit_For(self, node):
|
||||
self.generic_visit(node)
|
||||
body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
|
||||
|
||||
i_var = self.context.namer.new_symbol('i', body_scope.referenced)
|
||||
n_var = self.context.namer.new_symbol('n', body_scope.referenced)
|
||||
iterated_var = self.context.namer.new_symbol('iterated',
|
||||
body_scope.referenced)
|
||||
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
|
||||
if anno.hasanno(node, 'extra_cond'):
|
||||
template = """
|
||||
i = 0
|
||||
n = len(loop_iter)
|
||||
iterated = loop_iter
|
||||
n = len(iterated)
|
||||
while i < n and extra_cond:
|
||||
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
|
||||
target = loop_iter[i]
|
||||
target = iterated[i]
|
||||
body
|
||||
i += 1
|
||||
"""
|
||||
@ -53,17 +57,18 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
|
||||
loop_iter=node.iter,
|
||||
target=node.target,
|
||||
body=node.body,
|
||||
i=self.context.namer.new_symbol('i', body_scope.referenced),
|
||||
n=self.context.namer.new_symbol('n', body_scope.referenced),
|
||||
i=i_var,
|
||||
n=n_var,
|
||||
iterated=iterated_var,
|
||||
extra_cond=anno.getanno(node, 'extra_cond'))
|
||||
else:
|
||||
template = """
|
||||
i = 0
|
||||
n = len(loop_iter)
|
||||
iterated = loop_iter
|
||||
n = len(iterated)
|
||||
while i < n:
|
||||
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
|
||||
target = loop_iter[i]
|
||||
body # pylint:disable=pointless-statement
|
||||
target = iterated[i]
|
||||
body
|
||||
i += 1
|
||||
"""
|
||||
repl = templates.replace(
|
||||
@ -71,8 +76,9 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
|
||||
loop_iter=node.iter,
|
||||
target=node.target,
|
||||
body=node.body,
|
||||
i=self.context.namer.new_symbol('i', body_scope.referenced),
|
||||
n=self.context.namer.new_symbol('n', body_scope.referenced))
|
||||
i=i_var,
|
||||
n=n_var,
|
||||
iterated=iterated_var)
|
||||
return repl
|
||||
|
||||
def visit_Continue(self, node):
|
||||
|
@ -42,6 +42,29 @@ class ControlFlowTest(converter_test_base.TestCase):
|
||||
l = []
|
||||
self.assertEqual(test_fn(l), result.test_fn(l))
|
||||
|
||||
def test_for_with_iterated_expression(self):
|
||||
|
||||
eval_count = [0]
|
||||
|
||||
def count_evals(x):
|
||||
eval_count[0] += 1
|
||||
return x
|
||||
|
||||
def test_fn(n):
|
||||
s = 0
|
||||
for e in count_evals(range(n)):
|
||||
s += e
|
||||
return s
|
||||
|
||||
node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
|
||||
node = for_loops.transform(node, self.ctx)
|
||||
|
||||
with self.compiled(node) as result:
|
||||
result.count_evals = count_evals
|
||||
self.assertEqual(test_fn(5), result.test_fn(5))
|
||||
# count_evals ran twice, once for test_fn and another for result.test_fn
|
||||
self.assertEqual(eval_count[0], 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -20,11 +20,13 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.py2tf.utils.builtins import dynamic_builtin
|
||||
from tensorflow.contrib.py2tf.utils.builtins import dynamic_print
|
||||
from tensorflow.contrib.py2tf.utils.builtins import dynamic_range
|
||||
from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns
|
||||
from tensorflow.contrib.py2tf.utils.misc import alias_tensors
|
||||
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond
|
||||
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
|
||||
from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
|
||||
from tensorflow.contrib.py2tf.utils.tensor_list import dynamic_list_append
|
||||
from tensorflow.contrib.py2tf.utils.testing import fake_tf
|
||||
from tensorflow.contrib.py2tf.utils.type_check import is_tensor
|
||||
from tensorflow.contrib.py2tf.utils.type_hints import set_element_type
|
||||
|
Loading…
x
Reference in New Issue
Block a user