Restore the safeguard around maximum_iterations, correctly this time.

PiperOrigin-RevId: 326217531
Change-Id: I313ca522ccfc7111d9977dbb3fc93de482debb0d
This commit is contained in:
Dan Moldovan 2020-08-12 06:09:09 -07:00 committed by TensorFlower Gardener
parent 445fccd1b0
commit 4e3283a891

View File

@ -81,6 +81,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
@ -445,6 +446,12 @@ def _py_for_stmt(iter_, extra_test, body, get_state, set_state):
body(target)
def _add_max_iterations_hint(opts, n):
# TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
opts['maximum_iterations'] = n
def _known_len_tf_for_stmt(
iter_, extra_test, body, get_state, set_state, symbol_names, opts):
"""Overload of for_stmt that iterates over TF entities that admit a length."""
@ -478,7 +485,7 @@ def _known_len_tf_for_stmt(
return control_flow_ops.cond(main_test, extra_test, lambda: False)
return main_test
opts['maximum_iterations'] = n
_add_max_iterations_hint(opts, n)
_tf_while_stmt(
aug_test,
@ -524,7 +531,7 @@ def _tf_ragged_for_stmt(
return control_flow_ops.cond(main_test, extra_test, lambda: False)
return main_test
opts['maximum_iterations'] = n
_add_max_iterations_hint(opts, n)
_tf_while_stmt(
aug_test,
@ -582,8 +589,9 @@ def _tf_range_for_stmt(
main_test = control_flow_ops.cond(main_test, extra_test, lambda: False)
return main_test
opts['maximum_iterations'] = math_ops.cast(
misc.get_range_len(start, limit, delta), dtypes.int32)
_add_max_iterations_hint(
opts,
math_ops.cast(misc.get_range_len(start, limit, delta), dtypes.int32))
_tf_while_stmt(
aug_test,