Fix incorrect usage of tensor dictionary keys
PiperOrigin-RevId: 261622991
This commit is contained in:
parent
864d2942fe
commit
00581c6347
@ -931,15 +931,15 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
||||
return captured_tensor
|
||||
|
||||
# Do not accumulate loop invariants.
|
||||
if (tensor in self._forward_graph.inputs and
|
||||
tensor in self._forward_graph.outputs):
|
||||
if (any(tensor is t for t in self._forward_graph.inputs) and
|
||||
any(tensor is t for t in self._forward_graph.outputs)):
|
||||
captured_tensor = super(_WhileBodyGradFuncGraph,
|
||||
self)._capture_helper(tensor, name)
|
||||
# Add to `popped_tensor_lists` so that this gets added to the list of
|
||||
# outputs.
|
||||
# TODO(srbs): Rename popped_tensor_lists.
|
||||
self.popped_tensor_lists[captured_tensor] = captured_tensor
|
||||
self._indirect_captures[tensor] = captured_tensor
|
||||
self.popped_tensor_lists[ops.tensor_id(captured_tensor)] = captured_tensor
|
||||
self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
|
||||
return captured_tensor
|
||||
|
||||
# Resource tensors are not accumulated and handled specially.
|
||||
|
Loading…
x
Reference in New Issue
Block a user