Add a workaround for slowness of tf.while_loop in the default executor when maximum_iterations is set. Fixes #40517.

PiperOrigin-RevId: 317753132
Change-Id: I7b5767019ceebd3f21975a990e7c2c05dd878ca6
This commit is contained in:
Dan Moldovan 2020-06-22 15:58:38 -07:00 committed by TensorFlower Gardener
parent 23d5a2e00a
commit 094340c9d4
1 changed files with 11 additions and 4 deletions

View File

@ -80,6 +80,7 @@ from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
@ -429,6 +430,8 @@ def _known_len_tf_for_stmt(
return control_flow_ops.cond(main_test, extra_test, lambda: False) return control_flow_ops.cond(main_test, extra_test, lambda: False)
return main_test return main_test
# TODO(b/159186914): Remove.
if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
opts['maximum_iterations'] = n opts['maximum_iterations'] = n
_tf_while_stmt( _tf_while_stmt(
@ -475,6 +478,8 @@ def _tf_ragged_for_stmt(
return control_flow_ops.cond(main_test, extra_test, lambda: False) return control_flow_ops.cond(main_test, extra_test, lambda: False)
return main_test return main_test
# TODO(b/159186914): Remove.
if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
opts['maximum_iterations'] = n opts['maximum_iterations'] = n
_tf_while_stmt( _tf_while_stmt(
@ -524,6 +529,8 @@ def _tf_range_for_stmt(
return control_flow_ops.cond(main_test, extra_test, lambda: False) return control_flow_ops.cond(main_test, extra_test, lambda: False)
return main_test return main_test
# TODO(b/159186914): Remove.
if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
opts['maximum_iterations'] = math_ops.cast( opts['maximum_iterations'] = math_ops.cast(
misc.get_range_len(start, limit, delta), dtypes.int32) misc.get_range_len(start, limit, delta), dtypes.int32)