Add a workaround for slowness of tf.while_loop in the default executor when maximum_iterations is set. Fixes #40517.

PiperOrigin-RevId: 317753132
Change-Id: I7b5767019ceebd3f21975a990e7c2c05dd878ca6
This commit is contained in:
Dan Moldovan 2020-06-22 15:58:38 -07:00 committed by TensorFlower Gardener
parent 23d5a2e00a
commit 094340c9d4
1 changed files with 11 additions and 4 deletions

View File

@ -80,6 +80,7 @@ from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
@ -429,7 +430,9 @@ def _known_len_tf_for_stmt(
return control_flow_ops.cond(main_test, extra_test, lambda: False) return control_flow_ops.cond(main_test, extra_test, lambda: False)
return main_test return main_test
opts['maximum_iterations'] = n # TODO(b/159186914): Remove.
if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
opts['maximum_iterations'] = n
_tf_while_stmt( _tf_while_stmt(
aug_test, aug_test,
@ -475,7 +478,9 @@ def _tf_ragged_for_stmt(
return control_flow_ops.cond(main_test, extra_test, lambda: False) return control_flow_ops.cond(main_test, extra_test, lambda: False)
return main_test return main_test
opts['maximum_iterations'] = n # TODO(b/159186914): Remove.
if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
opts['maximum_iterations'] = n
_tf_while_stmt( _tf_while_stmt(
aug_test, aug_test,
@ -524,8 +529,10 @@ def _tf_range_for_stmt(
return control_flow_ops.cond(main_test, extra_test, lambda: False) return control_flow_ops.cond(main_test, extra_test, lambda: False)
return main_test return main_test
opts['maximum_iterations'] = math_ops.cast( # TODO(b/159186914): Remove.
misc.get_range_len(start, limit, delta), dtypes.int32) if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
opts['maximum_iterations'] = math_ops.cast(
misc.get_range_len(start, limit, delta), dtypes.int32)
_tf_while_stmt( _tf_while_stmt(
aug_test, aug_test,