From 4e3283a891ff73078158059fe749db56f5447ac0 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Wed, 12 Aug 2020 06:09:09 -0700 Subject: [PATCH] Restore the safeguard around maximum_iterations, correctly this time. PiperOrigin-RevId: 326217531 Change-Id: I313ca522ccfc7111d9977dbb3fc93de482debb0d --- .../python/autograph/operators/control_flow.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index f194c446dc0..9bd139c031f 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -81,6 +81,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -445,6 +446,12 @@ def _py_for_stmt(iter_, extra_test, body, get_state, set_state): body(target) +def _add_max_iterations_hint(opts, n): + # TODO(b/159186914): Remove the safeguard, and always set maximum_iterations. + if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): + opts['maximum_iterations'] = n + + def _known_len_tf_for_stmt( iter_, extra_test, body, get_state, set_state, symbol_names, opts): """Overload of for_stmt that iterates over TF entities that admit a length.""" @@ -478,7 +485,7 @@ def _known_len_tf_for_stmt( return control_flow_ops.cond(main_test, extra_test, lambda: False) return main_test - opts['maximum_iterations'] = n + _add_max_iterations_hint(opts, n) _tf_while_stmt( aug_test, @@ -524,7 +531,7 @@ def _tf_ragged_for_stmt( return control_flow_ops.cond(main_test, extra_test, lambda: False) return main_test - opts['maximum_iterations'] = n + _add_max_iterations_hint(opts, n) _tf_while_stmt( aug_test, @@ -582,8 +589,9 @@ def _tf_range_for_stmt( main_test = control_flow_ops.cond(main_test, extra_test, lambda: False) return main_test - opts['maximum_iterations'] = math_ops.cast( - misc.get_range_len(start, limit, delta), dtypes.int32) + _add_max_iterations_hint( + opts, + math_ops.cast(misc.get_range_len(start, limit, delta), dtypes.int32)) _tf_while_stmt( aug_test,