From f4aa98af47b89a7c7920af4c21b7760073b9b5c6 Mon Sep 17 00:00:00 2001 From: Yanhua Sun Date: Wed, 14 Aug 2019 14:52:21 -0700 Subject: [PATCH] Fix while_v2 for eq change PiperOrigin-RevId: 263434775 --- tensorflow/python/ops/while_v2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index ea574514a81..47508873009 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -451,7 +451,8 @@ def _get_intermediates(func_graph): # 3. Do not accumulate loop vars that are returned as-is just like captured # tensors. intermediates = [] - reverse_captures = dict((v, k) for k, v in func_graph.captures) + reverse_captures = dict( + (v.experimental_ref(), k) for k, v in func_graph.captures) for op in func_graph.get_operations(): if op.type == "Identity": @@ -460,10 +461,11 @@ def _get_intermediates(func_graph): if op.type == "MutexLock": continue for o in op.outputs: - if (o != func_graph.inputs[0] and # Loop counter. + if (o is not func_graph.inputs[0] and # Loop counter. o.dtype != dtypes.resource and # Do not accumulate resource tensors. _get_accumulator(o) is None and # Has existing accumulator. - o not in reverse_captures): # Captured value, hence loop invariant. + o.experimental_ref() not in reverse_captures + ): # Captured value, hence loop invariant. intermediates.append(o) return intermediates