Cleanup for closed bugs: remove associated optimization patches. They no longer seem to have any effect. V2 GPU control flow remains ~20% slower than V1.

PiperOrigin-RevId: 287200087
Change-Id: Ia35f1bdaf9a0ff8b0081f78f853209f86acb010f
This commit is contained in:
Dan Moldovan 2019-12-26 10:16:07 -08:00 committed by TensorFlower Gardener
parent 95c535c6c2
commit 87f69493d2

View File

@ -79,7 +79,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
@ -410,9 +409,7 @@ def _known_len_tf_for_stmt(iter_,
lambda: False) lambda: False)
return iterate_index < n return iterate_index < n
# TODO(b/134181679): Let the op itself handle optimizations. opts['maximum_iterations'] = n
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
opts['maximum_iterations'] = n
results = _tf_while_stmt( results = _tf_while_stmt(
while_cond, while_cond,
@ -526,26 +523,9 @@ def _tf_range_for_stmt(iter_,
def while_cond(iterate, *loop_vars): def while_cond(iterate, *loop_vars):
"""Cond function for `tf.while_loop`.""" """Cond function for `tf.while_loop`."""
main_test = math_ops.logical_or(
def build_main_test(): math_ops.logical_and(delta >= 0, iterate < limit),
"""Main iteration condition.""" math_ops.logical_and(delta < 0, iterate > limit))
# TODO(b/138857806): The optimizer should handle this.
# LogicalAnd is slow on GPU so we avoid adding it if `delta` is a
# compile time constant.
delta_const = tensor_util.constant_value(delta)
if delta_const is not None:
# Support single element arrays.
delta_const = np.asscalar(delta_const)
if delta_const >= 0:
return iterate < limit
else:
return iterate > limit
else:
return math_ops.logical_or(
math_ops.logical_and(delta >= 0, iterate < limit),
math_ops.logical_and(delta < 0, iterate > limit))
main_test = build_main_test()
if extra_test is not None: if extra_test is not None:
return control_flow_ops.cond( return control_flow_ops.cond(
main_test, main_test,
@ -554,11 +534,8 @@ def _tf_range_for_stmt(iter_,
) )
return main_test return main_test
# TODO(b/134181679): The op should handle this optimizations. opts['maximum_iterations'] = math_ops.cast(
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): misc.get_range_len(start, limit, delta), dtypes.int32)
# This specific dtype is required by while_loop.
opts['maximum_iterations'] = math_ops.cast(
misc.get_range_len(start, limit, delta), dtypes.int32)
results = _tf_while_stmt( results = _tf_while_stmt(
while_cond, while_cond,