Optimize away the calculation of a range tensor for the pattern for i in tf.range
. Along with the performance improvement, this is more compatible with XLA because it avoids generating dynamically-shaped tensors. Fixes #30182.
PiperOrigin-RevId: 257394214
This commit is contained in:
parent
a216c03faf
commit
dd3ec92602
tensorflow/python/autograph
@ -62,16 +62,19 @@ from __future__ import print_function
|
||||
from tensorflow.python.autograph.operators import py_builtins
|
||||
from tensorflow.python.autograph.operators import special_values
|
||||
from tensorflow.python.autograph.utils import ag_logging
|
||||
from tensorflow.python.autograph.utils import misc
|
||||
from tensorflow.python.autograph.utils import tensors
|
||||
from tensorflow.python.data.experimental.ops import scan_ops
|
||||
from tensorflow.python.data.experimental.ops import take_while_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
|
||||
LIMIT_PYTHON_ITERATIONS = True
|
||||
@ -137,8 +140,12 @@ def for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
|
||||
Tuple containing the final state.
|
||||
"""
|
||||
if tensor_util.is_tensor(iter_):
|
||||
return _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
|
||||
init_vars)
|
||||
if tensors.is_range_tensor(iter_):
|
||||
return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
|
||||
init_vars)
|
||||
else:
|
||||
return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
|
||||
set_state, init_vars)
|
||||
|
||||
if isinstance(iter_, dataset_ops.DatasetV2):
|
||||
return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
|
||||
@ -207,8 +214,59 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
|
||||
init_vars=(0,) + init_vars,
|
||||
opts=dict(maximum_iterations=n))
|
||||
|
||||
# Dropping the iteration index because it's not syntactically visible.
|
||||
# TODO(mdan): Don't.
|
||||
# Note: the iteration index is not returned by the while loop, however
|
||||
# if a symbol with the same name exists outside the loop, it will be captured
|
||||
# by the loop variables and ultimately updated correctly.
|
||||
if isinstance(results, (tuple, list)):
|
||||
assert len(results) >= 1 # Has at least the iterate.
|
||||
if len(results) > 1:
|
||||
results = results[1:]
|
||||
else:
|
||||
results = ()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
|
||||
init_vars):
|
||||
"""Overload of for_stmt that iterates over a TF range (and elides it)."""
|
||||
_disallow_undefs_into_loop(*init_vars)
|
||||
|
||||
start, limit, delta = iter_.op.inputs
|
||||
|
||||
def while_body(iterate, *loop_vars):
|
||||
new_vars = body(iterate, *loop_vars)
|
||||
|
||||
loop_vars = (iterate + delta,)
|
||||
if new_vars:
|
||||
loop_vars += new_vars
|
||||
|
||||
return loop_vars
|
||||
|
||||
def while_cond(iterate, *loop_vars):
|
||||
main_test = math_ops.logical_or(
|
||||
math_ops.logical_and(delta >= 0, iterate < limit),
|
||||
math_ops.logical_and(delta < 0, iterate > limit))
|
||||
if extra_test is not None:
|
||||
return control_flow_ops.cond(
|
||||
main_test, lambda: extra_test(*loop_vars), lambda: False)
|
||||
return main_test
|
||||
|
||||
# This specific dtype is required by while_loop.
|
||||
maximum_iterations = math_ops.cast(
|
||||
misc.get_range_len(start, limit, delta), dtypes.int32)
|
||||
|
||||
results = _tf_while_stmt(
|
||||
while_cond,
|
||||
while_body,
|
||||
get_state,
|
||||
set_state,
|
||||
init_vars=(start,) + init_vars,
|
||||
opts=dict(maximum_iterations=maximum_iterations))
|
||||
|
||||
# Note: the iteration index is not returned by the while loop, however
|
||||
# if a symbol with the same name exists outside the loop, it will be captured
|
||||
# by the loop variables and ultimately updated correctly.
|
||||
if isinstance(results, (tuple, list)):
|
||||
assert len(results) >= 1 # Has at least the iterate.
|
||||
if len(results) > 1:
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -50,6 +51,39 @@ class ForLoopTest(test.TestCase):
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
|
||||
def test_range_tensor(self):
|
||||
with ops.Graph().as_default():
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(5),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
|
||||
def test_range_tensor_explicit_limit_delta(self):
|
||||
with ops.Graph().as_default():
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(-17, -3, 5),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 100 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (-171207,))
|
||||
|
||||
def test_range_tensor_negative_delta(self):
|
||||
with ops.Graph().as_default():
|
||||
s = control_flow.for_stmt(
|
||||
math_ops.range(17, 3, -5),
|
||||
extra_test=lambda s: True,
|
||||
body=lambda i, s: (s * 100 + i,),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
init_vars=(0,))
|
||||
self.assertEqual(self.evaluate(s), (171207,))
|
||||
|
||||
def test_tensor_with_extra_test_only_python_state(self):
|
||||
class MutableObject(object):
|
||||
field_1 = constant_op.constant(0, dtype=dtypes.int32)
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
def alias_tensors(*args):
|
||||
@ -55,3 +57,13 @@ def capitalize_initial(s):
|
||||
if s:
|
||||
return s[0].upper() + s[1:]
|
||||
return s
|
||||
|
||||
|
||||
def get_range_len(start, limit, delta):
|
||||
dist = ops.convert_to_tensor(limit - start)
|
||||
unadjusted_len = dist // delta
|
||||
adjustment = math_ops.cast(
|
||||
gen_math_ops.not_equal(dist % delta,
|
||||
array_ops.zeros_like(unadjusted_len)), dist.dtype)
|
||||
final_len = unadjusted_len + adjustment
|
||||
return gen_math_ops.maximum(final_len, array_ops.zeros_like(final_len))
|
||||
|
@ -19,6 +19,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.utils import misc
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework.constant_op import constant
|
||||
from tensorflow.python.ops.variables import Variable
|
||||
@ -61,6 +63,21 @@ class MiscTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(1, self.evaluate(new_a))
|
||||
|
||||
def test_get_range_len(self):
|
||||
get_range_as_graph = def_function.function(misc.get_range_len)
|
||||
test_range = [(i, constant_op.constant(i)) for i in range(-3, 3)]
|
||||
results = []
|
||||
for i, ti in test_range:
|
||||
for j, tj in test_range:
|
||||
for k, tk in test_range:
|
||||
if k == 0:
|
||||
continue
|
||||
results.append(((i, j, k), get_range_as_graph(ti, tj, tk)))
|
||||
|
||||
for (i, j, k), result_tensor in results:
|
||||
self.assertEqual(
|
||||
len(list(range(i, j, k))), self.evaluate(result_tensor))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -46,3 +46,8 @@ def is_tensor_list(t):
|
||||
# construct.
|
||||
return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and
|
||||
not t.shape.ndims)
|
||||
|
||||
|
||||
def is_range_tensor(t):
|
||||
"""Returns True if a tensor is the result of a tf.range op. Best effort."""
|
||||
return tensor_util.is_tensor(t) and hasattr(t, 'op') and t.op.type == 'Range'
|
||||
|
@ -22,6 +22,7 @@ from tensorflow.python.autograph.utils import tensors
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import list_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -52,6 +53,13 @@ class TensorsTest(test.TestCase):
|
||||
self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors()))
|
||||
self.assertFalse(tensors.is_tensor_list(None))
|
||||
|
||||
def is_range_tensor(self):
|
||||
self.assertTrue(tensors.is_range_tensor(math_ops.range(1)))
|
||||
self.assertTrue(tensors.is_range_tensor(math_ops.range(1, 2)))
|
||||
self.assertTrue(tensors.is_range_tensor(math_ops.range(1, 2, 3)))
|
||||
self.assertFalse(tensors.is_range_tensor(None))
|
||||
self.assertFalse(tensors.is_range_tensor(constant_op.constant(range(1))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user