Cleanup for closed bugs: remove associated optimization patches. They no longer seem to have any effect. V2 GPU control flow remains ~20% slower than V1.

PiperOrigin-RevId: 287200087
Change-Id: Ia35f1bdaf9a0ff8b0081f78f853209f86acb010f
This commit is contained in:
Dan Moldovan 2019-12-26 10:16:07 -08:00 committed by TensorFlower Gardener
parent 95c535c6c2
commit 87f69493d2

View File

@ -79,7 +79,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
@ -410,9 +409,7 @@ def _known_len_tf_for_stmt(iter_,
lambda: False)
return iterate_index < n
# TODO(b/134181679): Let the op itself handle optimizations.
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
opts['maximum_iterations'] = n
opts['maximum_iterations'] = n
results = _tf_while_stmt(
while_cond,
@ -526,26 +523,9 @@ def _tf_range_for_stmt(iter_,
def while_cond(iterate, *loop_vars):
"""Cond function for `tf.while_loop`."""
def build_main_test():
"""Main iteration condition."""
# TODO(b/138857806): The optimizer should handle this.
# LogicalAnd is slow on GPU so we avoid adding it if `delta` is a
# compile time constant.
delta_const = tensor_util.constant_value(delta)
if delta_const is not None:
# Support single element arrays.
delta_const = np.asscalar(delta_const)
if delta_const >= 0:
return iterate < limit
else:
return iterate > limit
else:
return math_ops.logical_or(
math_ops.logical_and(delta >= 0, iterate < limit),
math_ops.logical_and(delta < 0, iterate > limit))
main_test = build_main_test()
main_test = math_ops.logical_or(
math_ops.logical_and(delta >= 0, iterate < limit),
math_ops.logical_and(delta < 0, iterate > limit))
if extra_test is not None:
return control_flow_ops.cond(
main_test,
@ -554,11 +534,8 @@ def _tf_range_for_stmt(iter_,
)
return main_test
# TODO(b/134181679): The op should handle this optimizations.
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
# This specific dtype is required by while_loop.
opts['maximum_iterations'] = math_ops.cast(
misc.get_range_len(start, limit, delta), dtypes.int32)
opts['maximum_iterations'] = math_ops.cast(
misc.get_range_len(start, limit, delta), dtypes.int32)
results = _tf_while_stmt(
while_cond,