Fix incorrect usage of tensor dictionary keys

PiperOrigin-RevId: 261622991
This commit is contained in:
Gaurav Jain 2019-08-04 23:01:58 -07:00 committed by TensorFlower Gardener
parent 864d2942fe
commit 00581c6347

View File

@ -931,15 +931,15 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
return captured_tensor
# Do not accumulate loop invariants.
if (tensor in self._forward_graph.inputs and
tensor in self._forward_graph.outputs):
if (any(tensor is t for t in self._forward_graph.inputs) and
any(tensor is t for t in self._forward_graph.outputs)):
captured_tensor = super(_WhileBodyGradFuncGraph,
self)._capture_helper(tensor, name)
# Add to `popped_tensor_lists` so that this gets added to the list of
# outputs.
# TODO(srbs): Rename popped_tensor_lists.
self.popped_tensor_lists[captured_tensor] = captured_tensor
self._indirect_captures[tensor] = captured_tensor
self.popped_tensor_lists[ops.tensor_id(captured_tensor)] = captured_tensor
self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
return captured_tensor
# Resource tensors are not accumulated and handled specially.