Fix race condition in TensorArrayWrite grad.

It was possible for a TensorArrayWrite/Scatter/Split op to execute after its corresponding TensorArrayGrad op had already run. The TensorArrayGrad op must execute after the last write so that the gradient TensorArray is created with the correct size. Because TensorArrayGrad converts its source TensorArray to fixed-size, the final TensorArrayWrite would then raise an exception when it executed, since it could no longer grow the TensorArray. This fix adds a control dependency from the TensorArrayGrad op onto the output flow of the TensorArrayWrite op, ensuring that TensorArrayGrad runs only after the last write.

Also, fix the same pattern in the gradients of TensorArrayScatter and TensorArraySplit, even though the race has not been observed for those ops.

PiperOrigin-RevId: 303515506
Change-Id: Ia363b94922b8855f01faa801818839559d55b0da
This commit is contained in:
RJ Skerry-Ryan 2020-03-28 10:02:22 -07:00 committed by TensorFlower Gardener
parent a1a7af994d
commit e24331bf11
2 changed files with 20 additions and 10 deletions

View File

@ -410,7 +410,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 469> a = {{
static std::array<OpIndexInfo, 460> a = {{
{"Abs"},
{"AccumulateNV2"},
{"Acos"},
@ -821,20 +821,11 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
{"TensorArrayRead"},
{"TensorArrayReadV2"},
{"TensorArrayReadV3"},
{"TensorArrayScatter"},
{"TensorArrayScatterV2"},
{"TensorArrayScatterV3"},
{"TensorArraySize"},
{"TensorArraySizeV2"},
{"TensorArraySizeV3"},
{"TensorArraySplit"},
{"TensorArraySplitV2"},
{"TensorArraySplitV3"},
{"TensorArrayV2"},
{"TensorArrayV3"},
{"TensorArrayWrite"},
{"TensorArrayWriteV2"},
{"TensorArrayWriteV3"},
{"TensorListConcat", 1, {0}},
{"TensorListConcatLists"},
{"TensorListConcatV2", 1, {0}},

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
# TODO(b/31222613): These ops may be differentiable, and there may be
@ -130,6 +131,12 @@ def _TensorArrayWriteGrad(op, flow):
index = op.inputs[1]
dtype = op.get_attr("T")
grad_source = _GetGradSource(flow)
flow_out = array_ops.identity(op.outputs[0], "flow_out")
# Avoid a race condition where the TensorArrayGrad op is executed before the
# final TensorArrayWrite by adding a control dependency on the output flow of
# the write to the input flow to the TensorArrayGrad.
with ops.control_dependencies([flow_out]):
flow = array_ops.identity(flow, "write_barrier")
g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow,
colocate_with_first_write_call=False)
.grad(source=grad_source, flow=flow))
@ -185,6 +192,12 @@ def _TensorArrayScatterGrad(op, flow):
indices = op.inputs[1]
dtype = op.get_attr("T")
grad_source = _GetGradSource(flow)
flow_out = array_ops.identity(op.outputs[0], "flow_out")
# Avoid a race condition where the TensorArrayGrad op is executed before the
# TensorArrayScatter by adding a control dependency on the output flow of
# the scatter to the input flow to the TensorArrayGrad.
with ops.control_dependencies([flow_out]):
flow = array_ops.identity(flow, "write_barrier")
g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow,
colocate_with_first_write_call=False)
.grad(source=grad_source, flow=flow))
@ -240,6 +253,12 @@ def _TensorArraySplitGrad(op, flow):
handle = op.inputs[0]
dtype = op.get_attr("T")
grad_source = _GetGradSource(flow)
flow_out = array_ops.identity(op.outputs[0], "flow_out")
# Avoid a race condition where the TensorArrayGrad op is executed before the
# TensorArraySplit by adding a control dependency on the output flow of
# the split to the input flow to the TensorArrayGrad.
with ops.control_dependencies([flow_out]):
flow = array_ops.identity(flow, "write_barrier")
g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow,
colocate_with_first_write_call=False)
.grad(source=grad_source, flow=flow))