Extract the iterated expression of a for loop into a variable to avoid repeated staging.

PiperOrigin-RevId: 188316160
This commit is contained in:
A. Unique TensorFlower 2018-03-08 05:12:20 -08:00 committed by TensorFlower Gardener
parent 4ac1fee7f1
commit 51fd9d70b8
4 changed files with 44 additions and 13 deletions

View File

@ -51,7 +51,7 @@ class BuiltinFunctionTransformer(transformer.Base):
def visit_Call(self, node):
self.generic_visit(node)
# TODO(mdan): This won't work if the function was hidden.
if isinstance(node.func, gast.Name) and node.func.id in ('len',):
if isinstance(node.func, gast.Name) and node.func.id in ('len', 'range'):
return self._convert_builtin(node)
# Print needs to be handled separately because it can be read as statement.
if isinstance(node.func, gast.Name) and node.func.id == 'print':

View File

@ -37,14 +37,18 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
def visit_For(self, node):
self.generic_visit(node)
body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
i_var = self.context.namer.new_symbol('i', body_scope.referenced)
n_var = self.context.namer.new_symbol('n', body_scope.referenced)
iterated_var = self.context.namer.new_symbol('iterated',
body_scope.referenced)
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
if anno.hasanno(node, 'extra_cond'):
template = """
i = 0
n = len(loop_iter)
iterated = loop_iter
n = len(iterated)
while i < n and extra_cond:
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
target = loop_iter[i]
target = iterated[i]
body
i += 1
"""
@ -53,17 +57,18 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
loop_iter=node.iter,
target=node.target,
body=node.body,
i=self.context.namer.new_symbol('i', body_scope.referenced),
n=self.context.namer.new_symbol('n', body_scope.referenced),
i=i_var,
n=n_var,
iterated=iterated_var,
extra_cond=anno.getanno(node, 'extra_cond'))
else:
template = """
i = 0
n = len(loop_iter)
iterated = loop_iter
n = len(iterated)
while i < n:
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
target = loop_iter[i]
body # pylint:disable=pointless-statement
target = iterated[i]
body
i += 1
"""
repl = templates.replace(
@ -71,8 +76,9 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
loop_iter=node.iter,
target=node.target,
body=node.body,
i=self.context.namer.new_symbol('i', body_scope.referenced),
n=self.context.namer.new_symbol('n', body_scope.referenced))
i=i_var,
n=n_var,
iterated=iterated_var)
return repl
def visit_Continue(self, node):

View File

@ -42,6 +42,29 @@ class ControlFlowTest(converter_test_base.TestCase):
l = []
self.assertEqual(test_fn(l), result.test_fn(l))
def test_for_with_iterated_expression(self):
eval_count = [0]
def count_evals(x):
eval_count[0] += 1
return x
def test_fn(n):
s = 0
for e in count_evals(range(n)):
s += e
return s
node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
node = for_loops.transform(node, self.ctx)
with self.compiled(node) as result:
result.count_evals = count_evals
self.assertEqual(test_fn(5), result.test_fn(5))
# count_evals ran twice, once for test_fn and another for result.test_fn
self.assertEqual(eval_count[0], 2)
if __name__ == '__main__':
test.main()

View File

@ -20,11 +20,13 @@ from __future__ import print_function
from tensorflow.contrib.py2tf.utils.builtins import dynamic_builtin
from tensorflow.contrib.py2tf.utils.builtins import dynamic_print
from tensorflow.contrib.py2tf.utils.builtins import dynamic_range
from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns
from tensorflow.contrib.py2tf.utils.misc import alias_tensors
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
from tensorflow.contrib.py2tf.utils.tensor_list import dynamic_list_append
from tensorflow.contrib.py2tf.utils.testing import fake_tf
from tensorflow.contrib.py2tf.utils.type_check import is_tensor
from tensorflow.contrib.py2tf.utils.type_hints import set_element_type