Cleanup for closed bugs: remove associated optimization patches. They no longer seem to have any effect. V2 GPU control flow remains ~20% slower than V1.
PiperOrigin-RevId: 287200087 Change-Id: Ia35f1bdaf9a0ff8b0081f78f853209f86acb010f
This commit is contained in:
parent
95c535c6c2
commit
87f69493d2
@ -79,7 +79,6 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -410,9 +409,7 @@ def _known_len_tf_for_stmt(iter_,
|
||||
lambda: False)
|
||||
return iterate_index < n
|
||||
|
||||
# TODO(b/134181679): Let the op itself handle optimizations.
|
||||
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
|
||||
opts['maximum_iterations'] = n
|
||||
opts['maximum_iterations'] = n
|
||||
|
||||
results = _tf_while_stmt(
|
||||
while_cond,
|
||||
@ -526,26 +523,9 @@ def _tf_range_for_stmt(iter_,
|
||||
|
||||
def while_cond(iterate, *loop_vars):
|
||||
"""Cond function for `tf.while_loop`."""
|
||||
|
||||
def build_main_test():
|
||||
"""Main iteration condition."""
|
||||
# TODO(b/138857806): The optimizer should handle this.
|
||||
# LogicalAnd is slow on GPU so we avoid adding it if `delta` is a
|
||||
# compile time constant.
|
||||
delta_const = tensor_util.constant_value(delta)
|
||||
if delta_const is not None:
|
||||
# Support single element arrays.
|
||||
delta_const = np.asscalar(delta_const)
|
||||
if delta_const >= 0:
|
||||
return iterate < limit
|
||||
else:
|
||||
return iterate > limit
|
||||
else:
|
||||
return math_ops.logical_or(
|
||||
math_ops.logical_and(delta >= 0, iterate < limit),
|
||||
math_ops.logical_and(delta < 0, iterate > limit))
|
||||
|
||||
main_test = build_main_test()
|
||||
main_test = math_ops.logical_or(
|
||||
math_ops.logical_and(delta >= 0, iterate < limit),
|
||||
math_ops.logical_and(delta < 0, iterate > limit))
|
||||
if extra_test is not None:
|
||||
return control_flow_ops.cond(
|
||||
main_test,
|
||||
@ -554,11 +534,8 @@ def _tf_range_for_stmt(iter_,
|
||||
)
|
||||
return main_test
|
||||
|
||||
# TODO(b/134181679): The op should handle this optimizations.
|
||||
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
|
||||
# This specific dtype is required by while_loop.
|
||||
opts['maximum_iterations'] = math_ops.cast(
|
||||
misc.get_range_len(start, limit, delta), dtypes.int32)
|
||||
opts['maximum_iterations'] = math_ops.cast(
|
||||
misc.get_range_len(start, limit, delta), dtypes.int32)
|
||||
|
||||
results = _tf_while_stmt(
|
||||
while_cond,
|
||||
|
Loading…
Reference in New Issue
Block a user