Fix while_v2 for eq change
PiperOrigin-RevId: 263434775
This commit is contained in:
parent
853ff441b7
commit
f4aa98af47
@ -451,7 +451,8 @@ def _get_intermediates(func_graph):
|
|||||||
# 3. Do not accumulate loop vars that are returned as-is just like captured
|
# 3. Do not accumulate loop vars that are returned as-is just like captured
|
||||||
# tensors.
|
# tensors.
|
||||||
intermediates = []
|
intermediates = []
|
||||||
reverse_captures = dict((v, k) for k, v in func_graph.captures)
|
reverse_captures = dict(
|
||||||
|
(v.experimental_ref(), k) for k, v in func_graph.captures)
|
||||||
|
|
||||||
for op in func_graph.get_operations():
|
for op in func_graph.get_operations():
|
||||||
if op.type == "Identity":
|
if op.type == "Identity":
|
||||||
@ -460,10 +461,11 @@ def _get_intermediates(func_graph):
|
|||||||
if op.type == "MutexLock":
|
if op.type == "MutexLock":
|
||||||
continue
|
continue
|
||||||
for o in op.outputs:
|
for o in op.outputs:
|
||||||
if (o != func_graph.inputs[0] and # Loop counter.
|
if (o is not func_graph.inputs[0] and # Loop counter.
|
||||||
o.dtype != dtypes.resource and # Do not accumulate resource tensors.
|
o.dtype != dtypes.resource and # Do not accumulate resource tensors.
|
||||||
_get_accumulator(o) is None and # Has existing accumulator.
|
_get_accumulator(o) is None and # Has existing accumulator.
|
||||||
o not in reverse_captures): # Captured value, hence loop invariant.
|
o.experimental_ref() not in reverse_captures
|
||||||
|
): # Captured value, hence loop invariant.
|
||||||
intermediates.append(o)
|
intermediates.append(o)
|
||||||
return intermediates
|
return intermediates
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user