Restore the safeguard around maximum_iterations, correctly this time.
PiperOrigin-RevId: 326217531 Change-Id: I313ca522ccfc7111d9977dbb3fc93de482debb0d
This commit is contained in:
parent
445fccd1b0
commit
4e3283a891
@ -81,6 +81,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -445,6 +446,12 @@ def _py_for_stmt(iter_, extra_test, body, get_state, set_state):
|
||||
body(target)
|
||||
|
||||
|
||||
def _add_max_iterations_hint(opts, n):
|
||||
# TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
|
||||
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
|
||||
opts['maximum_iterations'] = n
|
||||
|
||||
|
||||
def _known_len_tf_for_stmt(
|
||||
iter_, extra_test, body, get_state, set_state, symbol_names, opts):
|
||||
"""Overload of for_stmt that iterates over TF entities that admit a length."""
|
||||
@ -478,7 +485,7 @@ def _known_len_tf_for_stmt(
|
||||
return control_flow_ops.cond(main_test, extra_test, lambda: False)
|
||||
return main_test
|
||||
|
||||
opts['maximum_iterations'] = n
|
||||
_add_max_iterations_hint(opts, n)
|
||||
|
||||
_tf_while_stmt(
|
||||
aug_test,
|
||||
@ -524,7 +531,7 @@ def _tf_ragged_for_stmt(
|
||||
return control_flow_ops.cond(main_test, extra_test, lambda: False)
|
||||
return main_test
|
||||
|
||||
opts['maximum_iterations'] = n
|
||||
_add_max_iterations_hint(opts, n)
|
||||
|
||||
_tf_while_stmt(
|
||||
aug_test,
|
||||
@ -582,8 +589,9 @@ def _tf_range_for_stmt(
|
||||
main_test = control_flow_ops.cond(main_test, extra_test, lambda: False)
|
||||
return main_test
|
||||
|
||||
opts['maximum_iterations'] = math_ops.cast(
|
||||
misc.get_range_len(start, limit, delta), dtypes.int32)
|
||||
_add_max_iterations_hint(
|
||||
opts,
|
||||
math_ops.cast(misc.get_range_len(start, limit, delta), dtypes.int32))
|
||||
|
||||
_tf_while_stmt(
|
||||
aug_test,
|
||||
|
Loading…
Reference in New Issue
Block a user