Fix while_v2 for eq change

PiperOrigin-RevId: 263434775
This commit is contained in:
Yanhua Sun 2019-08-14 14:52:21 -07:00 committed by Gaurav Jain
parent 853ff441b7
commit f4aa98af47

View File

@ -451,7 +451,8 @@ def _get_intermediates(func_graph):
# 3. Do not accumulate loop vars that are returned as-is just like captured
# tensors.
intermediates = []
reverse_captures = dict((v, k) for k, v in func_graph.captures)
reverse_captures = dict(
(v.experimental_ref(), k) for k, v in func_graph.captures)
for op in func_graph.get_operations():
if op.type == "Identity":
@ -460,10 +461,11 @@ def _get_intermediates(func_graph):
if op.type == "MutexLock":
continue
for o in op.outputs:
if (o != func_graph.inputs[0] and # Loop counter.
if (o is not func_graph.inputs[0] and # Loop counter.
o.dtype != dtypes.resource and # Do not accumulate resource tensors.
_get_accumulator(o) is None and # Has existing accumulator.
o not in reverse_captures): # Captured value, hence loop invariant.
o.experimental_ref() not in reverse_captures
): # Captured value, hence loop invariant.
intermediates.append(o)
return intermediates