Add a workaround for slowness of tf.while_loop in the default executor when maximum_iterations is set. Fixes #40517.
PiperOrigin-RevId: 317753132 Change-Id: I7b5767019ceebd3f21975a990e7c2c05dd878ca6
This commit is contained in:
parent
23d5a2e00a
commit
094340c9d4
|
@ -80,6 +80,7 @@ from tensorflow.python.framework import func_graph
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
@ -429,6 +430,8 @@ def _known_len_tf_for_stmt(
|
||||||
return control_flow_ops.cond(main_test, extra_test, lambda: False)
|
return control_flow_ops.cond(main_test, extra_test, lambda: False)
|
||||||
return main_test
|
return main_test
|
||||||
|
|
||||||
|
# TODO(b/159186914): Remove.
|
||||||
|
if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
|
||||||
opts['maximum_iterations'] = n
|
opts['maximum_iterations'] = n
|
||||||
|
|
||||||
_tf_while_stmt(
|
_tf_while_stmt(
|
||||||
|
@ -475,6 +478,8 @@ def _tf_ragged_for_stmt(
|
||||||
return control_flow_ops.cond(main_test, extra_test, lambda: False)
|
return control_flow_ops.cond(main_test, extra_test, lambda: False)
|
||||||
return main_test
|
return main_test
|
||||||
|
|
||||||
|
# TODO(b/159186914): Remove.
|
||||||
|
if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
|
||||||
opts['maximum_iterations'] = n
|
opts['maximum_iterations'] = n
|
||||||
|
|
||||||
_tf_while_stmt(
|
_tf_while_stmt(
|
||||||
|
@ -524,6 +529,8 @@ def _tf_range_for_stmt(
|
||||||
return control_flow_ops.cond(main_test, extra_test, lambda: False)
|
return control_flow_ops.cond(main_test, extra_test, lambda: False)
|
||||||
return main_test
|
return main_test
|
||||||
|
|
||||||
|
# TODO(b/159186914): Remove.
|
||||||
|
if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
|
||||||
opts['maximum_iterations'] = math_ops.cast(
|
opts['maximum_iterations'] = math_ops.cast(
|
||||||
misc.get_range_len(start, limit, delta), dtypes.int32)
|
misc.get_range_len(start, limit, delta), dtypes.int32)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue