From 87f69493d2553ceb9fdb94ca7f203dc1c8e417ea Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Thu, 26 Dec 2019 10:16:07 -0800 Subject: [PATCH] Cleanup for closed bugs: remove associated optimization patches. They no longer seem to have any effect. V2 GPU control flow remains ~20% slower than V1. PiperOrigin-RevId: 287200087 Change-Id: Ia35f1bdaf9a0ff8b0081f78f853209f86acb010f --- .../autograph/operators/control_flow.py | 35 ++++--------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index c994c081d92..f48bacc3dd2 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -79,7 +79,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -410,9 +409,7 @@ def _known_len_tf_for_stmt(iter_, lambda: False) return iterate_index < n - # TODO(b/134181679): Let the op itself handle optimizations. - if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): - opts['maximum_iterations'] = n + opts['maximum_iterations'] = n results = _tf_while_stmt( while_cond, @@ -526,26 +523,9 @@ def _tf_range_for_stmt(iter_, def while_cond(iterate, *loop_vars): """Cond function for `tf.while_loop`.""" - - def build_main_test(): - """Main iteration condition.""" - # TODO(b/138857806): The optimizer should handle this. - # LogicalAnd is slow on GPU so we avoid adding it if `delta` is a - # compile time constant. - delta_const = tensor_util.constant_value(delta) - if delta_const is not None: - # Support single element arrays. - delta_const = np.asscalar(delta_const) - if delta_const >= 0: - return iterate < limit - else: - return iterate > limit - else: - return math_ops.logical_or( - math_ops.logical_and(delta >= 0, iterate < limit), - math_ops.logical_and(delta < 0, iterate > limit)) - - main_test = build_main_test() + main_test = math_ops.logical_or( + math_ops.logical_and(delta >= 0, iterate < limit), + math_ops.logical_and(delta < 0, iterate > limit)) if extra_test is not None: return control_flow_ops.cond( main_test, @@ -554,11 +534,8 @@ def _tf_range_for_stmt(iter_, ) return main_test - # TODO(b/134181679): The op should handle this optimizations. - if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): - # This specific dtype is required by while_loop. - opts['maximum_iterations'] = math_ops.cast( - misc.get_range_len(start, limit, delta), dtypes.int32) + opts['maximum_iterations'] = math_ops.cast( + misc.get_range_len(start, limit, delta), dtypes.int32) results = _tf_while_stmt( while_cond,