Optimize away the calculation of a range tensor for the pattern `for i in tf.range(...)`. Along with the performance improvement, this is more compatible with XLA because it avoids generating dynamically-shaped tensors. Fixes .

PiperOrigin-RevId: 257394214
This commit is contained in:
Dan Moldovan 2019-07-10 06:47:55 -07:00 committed by TensorFlower Gardener
parent a216c03faf
commit dd3ec92602
6 changed files with 138 additions and 4 deletions

View File

@ -62,16 +62,19 @@ from __future__ import print_function
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import special_values
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import misc
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.experimental.ops import take_while_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
LIMIT_PYTHON_ITERATIONS = True
@ -137,8 +140,12 @@ def for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
Tuple containing the final state.
"""
if tensor_util.is_tensor(iter_):
return _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
init_vars)
if tensors.is_range_tensor(iter_):
return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
init_vars)
else:
return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
set_state, init_vars)
if isinstance(iter_, dataset_ops.DatasetV2):
return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
@ -207,8 +214,59 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
init_vars=(0,) + init_vars,
opts=dict(maximum_iterations=n))
# Dropping the iteration index because it's not syntactically visible.
# TODO(mdan): Don't.
# Note: the iteration index is not returned by the while loop, however
# if a symbol with the same name exists outside the loop, it will be captured
# by the loop variables and ultimately updated correctly.
if isinstance(results, (tuple, list)):
assert len(results) >= 1 # Has at least the iterate.
if len(results) > 1:
results = results[1:]
else:
results = ()
return results
def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
init_vars):
"""Overload of for_stmt that iterates over a TF range (and elides it)."""
_disallow_undefs_into_loop(*init_vars)
start, limit, delta = iter_.op.inputs
def while_body(iterate, *loop_vars):
new_vars = body(iterate, *loop_vars)
loop_vars = (iterate + delta,)
if new_vars:
loop_vars += new_vars
return loop_vars
def while_cond(iterate, *loop_vars):
main_test = math_ops.logical_or(
math_ops.logical_and(delta >= 0, iterate < limit),
math_ops.logical_and(delta < 0, iterate > limit))
if extra_test is not None:
return control_flow_ops.cond(
main_test, lambda: extra_test(*loop_vars), lambda: False)
return main_test
# This specific dtype is required by while_loop.
maximum_iterations = math_ops.cast(
misc.get_range_len(start, limit, delta), dtypes.int32)
results = _tf_while_stmt(
while_cond,
while_body,
get_state,
set_state,
init_vars=(start,) + init_vars,
opts=dict(maximum_iterations=maximum_iterations))
# Note: the iteration index is not returned by the while loop, however
# if a symbol with the same name exists outside the loop, it will be captured
# by the loop variables and ultimately updated correctly.
if isinstance(results, (tuple, list)):
assert len(results) >= 1 # Has at least the iterate.
if len(results) > 1:

View File

@ -33,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -50,6 +51,39 @@ class ForLoopTest(test.TestCase):
init_vars=(0,))
self.assertEqual(self.evaluate(s), (1234,))
def test_range_tensor(self):
  """Checks `for_stmt` over `math_ops.range(5)` with a single loop variable."""

  # Each iteration appends the iterate as a decimal digit, so visiting
  # 0, 1, 2, 3, 4 in order accumulates to 01234 == 1234.
  def append_digit(i, s):
    return (s * 10 + i,)

  with ops.Graph().as_default():
    final_state = control_flow.for_stmt(
        math_ops.range(5),
        extra_test=lambda s: True,
        body=append_digit,
        get_state=lambda: (),
        set_state=lambda _: None,
        init_vars=(0,))
    self.assertEqual(self.evaluate(final_state), (1234,))
def test_range_tensor_explicit_limit_delta(self):
  """Checks `for_stmt` over a range with explicit start, limit and delta."""

  # The iterates -17, -12, -7 are folded two decimal digits at a time:
  # ((0 * 100 - 17) * 100 - 12) * 100 - 7 == -171207.
  def append_two_digits(i, s):
    return (s * 100 + i,)

  with ops.Graph().as_default():
    final_state = control_flow.for_stmt(
        math_ops.range(-17, -3, 5),
        extra_test=lambda s: True,
        body=append_two_digits,
        get_state=lambda: (),
        set_state=lambda _: None,
        init_vars=(0,))
    self.assertEqual(self.evaluate(final_state), (-171207,))
def test_range_tensor_negative_delta(self):
  """Checks `for_stmt` over a range that counts down (negative delta)."""

  # The iterates 17, 12, 7 are folded two decimal digits at a time:
  # ((0 * 100 + 17) * 100 + 12) * 100 + 7 == 171207.
  def append_two_digits(i, s):
    return (s * 100 + i,)

  with ops.Graph().as_default():
    final_state = control_flow.for_stmt(
        math_ops.range(17, 3, -5),
        extra_test=lambda s: True,
        body=append_two_digits,
        get_state=lambda: (),
        set_state=lambda _: None,
        init_vars=(0,))
    self.assertEqual(self.evaluate(final_state), (171207,))
def test_tensor_with_extra_test_only_python_state(self):
class MutableObject(object):
field_1 = constant_op.constant(0, dtype=dtypes.int32)

View File

@ -20,6 +20,8 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
def alias_tensors(*args):
@ -55,3 +57,13 @@ def capitalize_initial(s):
if s:
return s[0].upper() + s[1:]
return s
def get_range_len(start, limit, delta):
  """Computes the number of elements of a range without building the range.

  The result is the floor-divided distance between `limit` and `start`, plus
  one when the division leaves a remainder, clamped below at zero.

  NOTE: a zero `delta` is not checked for here.

  Args:
    start: First value of the range.
    limit: Exclusive bound of the range.
    delta: Step between consecutive values.

  Returns:
    A tensor holding the element count, with the dtype of `limit - start`.
  """
  span = ops.convert_to_tensor(limit - start)
  whole_steps = span // delta
  # A non-zero remainder means one more element fits before reaching `limit`.
  remainder = span % delta
  has_partial_step = gen_math_ops.not_equal(
      remainder, array_ops.zeros_like(whole_steps))
  length = whole_steps + math_ops.cast(has_partial_step, span.dtype)
  # When `delta` points away from `limit` the sum is negative; clamp to zero.
  return gen_math_ops.maximum(length, array_ops.zeros_like(length))

View File

@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.utils import misc
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.ops.variables import Variable
@ -61,6 +63,21 @@ class MiscTest(test.TestCase):
with self.cached_session() as sess:
self.assertEqual(1, self.evaluate(new_a))
def test_get_range_len(self):
  """Compares `misc.get_range_len` with Python's `range` on small inputs."""
  range_len_fn = def_function.function(misc.get_range_len)
  # Pair each Python int with its constant tensor: the int feeds the expected
  # value, the tensor feeds the function under test.
  operands = [(v, constant_op.constant(v)) for v in range(-3, 3)]

  # Build every result tensor first, then evaluate them. A zero step is
  # skipped because Python's `range` rejects it.
  cases = [
      ((i, j, k), range_len_fn(ti, tj, tk))
      for i, ti in operands
      for j, tj in operands
      for k, tk in operands
      if k != 0
  ]

  for (i, j, k), result_tensor in cases:
    expected = len(list(range(i, j, k)))
    self.assertEqual(expected, self.evaluate(result_tensor))
if __name__ == '__main__':
test.main()

View File

@ -46,3 +46,8 @@ def is_tensor_list(t):
# construct.
return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and
not t.shape.ndims)
def is_range_tensor(t):
  """Tells whether `t` is a tensor produced by a `Range` op. Best effort.

  Args:
    t: Any value.

  Returns:
    True if `t` is a tensor that exposes an `op` attribute whose type is
    'Range'; False otherwise.
  """
  if not tensor_util.is_tensor(t):
    return False
  # Only tensors that expose their producing op can be inspected.
  if not hasattr(t, 'op'):
    return False
  return t.op.type == 'Range'

View File

@ -22,6 +22,7 @@ from tensorflow.python.autograph.utils import tensors
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
@ -52,6 +53,13 @@ class TensorsTest(test.TestCase):
self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors()))
self.assertFalse(tensors.is_tensor_list(None))
def test_is_range_tensor(self):
  """Checks `tensors.is_range_tensor` on range tensors and on other values."""
  # The method name must start with `test`; otherwise the unittest loader does
  # not collect it and these assertions never run.

  # Every arity of `math_ops.range` is expected to be recognized.
  self.assertTrue(tensors.is_range_tensor(math_ops.range(1)))
  self.assertTrue(tensors.is_range_tensor(math_ops.range(1, 2)))
  self.assertTrue(tensors.is_range_tensor(math_ops.range(1, 2, 3)))
  # Non-tensors, and tensors not built by a range op, are rejected.
  self.assertFalse(tensors.is_range_tensor(None))
  self.assertFalse(tensors.is_range_tensor(constant_op.constant(range(1))))
if __name__ == '__main__':
test.main()