Fix while_v2 for eq change
PiperOrigin-RevId: 263434775
This commit is contained in:
parent
853ff441b7
commit
f4aa98af47
@ -451,7 +451,8 @@ def _get_intermediates(func_graph):
|
||||
# 3. Do not accumulate loop vars that are returned as-is just like captured
|
||||
# tensors.
|
||||
intermediates = []
|
||||
reverse_captures = dict((v, k) for k, v in func_graph.captures)
|
||||
reverse_captures = dict(
|
||||
(v.experimental_ref(), k) for k, v in func_graph.captures)
|
||||
|
||||
for op in func_graph.get_operations():
|
||||
if op.type == "Identity":
|
||||
@ -460,10 +461,11 @@ def _get_intermediates(func_graph):
|
||||
if op.type == "MutexLock":
|
||||
continue
|
||||
for o in op.outputs:
|
||||
if (o != func_graph.inputs[0] and # Loop counter.
|
||||
if (o is not func_graph.inputs[0] and # Loop counter.
|
||||
o.dtype != dtypes.resource and # Do not accumulate resource tensors.
|
||||
_get_accumulator(o) is None and # Has existing accumulator.
|
||||
o not in reverse_captures): # Captured value, hence loop invariant.
|
||||
o.experimental_ref() not in reverse_captures
|
||||
): # Captured value, hence loop invariant.
|
||||
intermediates.append(o)
|
||||
return intermediates
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user