[TF:XLA] Add explicit padding support for Conv2D in tf2xla.
PiperOrigin-RevId: 232669459
Commit: 5b5636642d (parent: 9f6937e0d1)
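Note: at the Python level, "explicit padding" means passing per-dimension [before, after] pad pairs instead of the usual "SAME"/"VALID" strings; this change lets tf2xla compile that form of Conv2D. A minimal usage sketch (shapes and values are illustrative, not taken from this change):

    import tensorflow as tf

    x = tf.ones([1, 4, 4, 1])   # NHWC input: batch=1, 4x4 spatial, 1 channel
    w = tf.ones([3, 3, 1, 1])   # 3x3 filter, 1 input channel -> 1 output channel

    # One [before, after] pair per input dimension, in NHWC order; the batch and
    # channel pairs must be [0, 0].
    y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1],
                     padding=[[0, 0], [1, 2], [3, 4], [0, 0]])

    # Output height = 4 + (1 + 2) - 3 + 1 = 5; output width = 4 + (3 + 4) - 3 + 1 = 9.
    print(y.shape)  # (1, 5, 9, 1)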
@@ -203,7 +203,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes(
     StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
     const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
     absl::Span<const int32> dilations, const std::vector<int32>& strides,
-    Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) {
+    Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
+    absl::Span<const int64> explicit_paddings) {
   TensorShape input_tensor_shape, filter_tensor_shape,
       out_backprop_tensor_shape;
   TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
@@ -212,8 +213,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes(
       XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
   return ConvBackpropComputeDimensionsV2(
       label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
-      out_backprop_tensor_shape, dilations, strides, padding,
-      /*explicit_paddings=*/{}, data_format, dims);
+      out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings,
+      data_format, dims);
 }

 }  // anonymous namespace
@@ -227,10 +228,9 @@ xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
   TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
   TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
   TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
-  // TODO(reedwm): Support explicit padding.
   if (attrs.padding == EXPLICIT) {
-    return errors::Unimplemented(
-        "XLA does not yet support Conv2D with explicit padding.");
+    TF_RETURN_IF_ERROR(
+        ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
   }

   string data_format;
@@ -303,6 +303,11 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
     window_strides[i] = attrs.strides.at(dim);
     rhs_dilation[i] = attrs.dilations.at(dim);
+
+    if (attrs.padding == EXPLICIT) {
+      padding[i] = {attrs.explicit_paddings.at(dim * 2),
+                    attrs.explicit_paddings.at(dim * 2 + 1)};
+    }

     int64 unused_output_size;
     TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
         input_shape.dimensions(dim), filter_shape.dimensions(i),
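For reference, the explicit_paddings attribute is a flat list with one (before, after) pair per input dimension, and the hunk above picks out the pairs belonging to the spatial dimensions. A small plain-Python sketch of that indexing (illustrative only, not TF code):

    # Flat attribute: one [before, after] pair per dimension, NHWC order.
    explicit_paddings = [0, 0, 1, 2, 3, 4, 0, 0]

    spatial_dims = [1, 2]  # height and width for NHWC
    padding = [(explicit_paddings[2 * d], explicit_paddings[2 * d + 1])
               for d in spatial_dims]
    print(padding)  # [(1, 2), (3, 4)]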
@@ -337,7 +342,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
   TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
       type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
       out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
-      attrs.data_format, &dims));
+      attrs.data_format, &dims, attrs.explicit_paddings));

   // The input gradients are computed by a convolution of the output
   // gradients and the filter, with some appropriate padding. See the
@@ -420,7 +425,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
   TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
       type_string, attrs.num_spatial_dims, activations_shape,
       expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
-      attrs.padding, attrs.data_format, &dims));
+      attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));

   // The activations (inputs) form the LHS of the convolution.
   // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
@@ -469,6 +474,8 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
     int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
     dnums.add_input_spatial_dimensions(dim);
     dnums.add_kernel_spatial_dimensions(dim);
+    rhs_dilation[i] = dims.spatial_dims[i].stride;
+    window_strides[i] = attrs.dilations[dim];

     // We will also need to pad the input with zeros such that after the
     // convolution, we get the right size for the filter.
@@ -495,6 +502,8 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
     // We apply negative padding in this case.
     const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;

+    // + For the EXPLICIT padding, we pad the top/left side with the explicit
+    //   padding and pad the bottom/right side with the remaining space.
     // + For the VALID padding, we don't pad anything on the top/left side
     //   and pad the bottom/right side with the remaining space.
     // + For the SAME padding, we pad top/left side the same as bottom/right
@@ -503,12 +512,12 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
     // In addition, if the padded input size is smaller than the input size,
     // we need to ignore some training elements of the input. We do this by
     // applying negative padding on the right/bottom.
-    const int64 pad_before =
-        attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
-
+    const int64 pad_before = attrs.padding == Padding::EXPLICIT
+                                 ? attrs.explicit_paddings[2 * dim]
+                                 : attrs.padding == Padding::SAME
+                                       ? std::max<int64>(pad_total / 2, 0)
+                                       : 0;
     padding[i] = {pad_before, pad_total - pad_before};
-    rhs_dilation[i] = dims.spatial_dims[i].stride;
-    window_strides[i] = attrs.dilations[dim];
   }

   // Besides padding the input, we will also expand output_rows to
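The comments above describe how the top/left padding of the re-padded input is chosen for the filter-gradient convolution. A rough plain-Python sketch of that selection, mirroring the ternary in the hunk (illustrative only; the helper name is made up):

    def choose_padding(pad_total, mode, explicit_before=0):
        # EXPLICIT: the caller fixes the top/left amount; the rest goes bottom/right.
        # SAME: split roughly evenly, with any extra element on the bottom/right.
        # VALID: everything (possibly negative, to drop unused input) goes bottom/right.
        if mode == "EXPLICIT":
            before = explicit_before
        elif mode == "SAME":
            before = max(pad_total // 2, 0)
        else:  # VALID
            before = 0
        return before, pad_total - before

    print(choose_padding(5, "SAME"))         # (2, 3)
    print(choose_padding(5, "EXPLICIT", 4))  # (4, 1)
    print(choose_padding(-1, "VALID"))       # (0, -1): negative padding trims unused input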
@@ -47,6 +47,7 @@ struct ConvOpAttrs {
   std::vector<int32> dilations;
   std::vector<int32> strides;
   Padding padding;
+  std::vector<int64> explicit_paddings;
   TensorFormat data_format;
 };

@@ -78,9 +78,6 @@ Status ConvBackpropExtractAndVerifyDimension(
         " stride: ", dim->stride, " dilation: ", dim->dilation);
   }

-  // TODO(reedwm): Correctly handle explicit padding here. The rest of the
-  // fields set on 'dim' are only used in XLA. TensorFlow ops do not yet support
-  // explicit padding for XLA.
   int64 effective_filter_size = (dim->filter_size - 1) * dim->dilation + 1;
   dim->expanded_output_size = (dim->output_size - 1) * dim->stride + 1;
   const auto padded_out_size = dim->input_size + effective_filter_size - 1;
@@ -102,7 +99,7 @@ Status ConvBackpropComputeDimensionsV2(
     StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
     const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
     const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
-    Padding padding, const std::vector<int64>& explicit_paddings,
+    Padding padding, absl::Span<const int64> explicit_paddings,
     TensorFormat data_format, ConvBackpropDimensions* dims) {
   // The + 2 in the following line is for the batch and feature dimensions.
   const int num_dims = num_spatial_dims + 2;
@@ -222,7 +222,7 @@ struct ConvBackpropSpatialDimension {
   int64 stride;
   int64 dilation;
-
+  // The following fields are valid only if the padding is not EXPLICIT.
   // Output size after scaling by the stride.
   int64 expanded_output_size;

   // Number of padding elements to be added before/after this dimension of
@@ -270,7 +270,7 @@ Status ConvBackpropComputeDimensionsV2(
     StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
     const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
     const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
-    Padding padding, const std::vector<int64>& explicit_paddings,
+    Padding padding, absl::Span<const int64> explicit_paddings,
     TensorFormat data_format, ConvBackpropDimensions* dims);

 }  // namespace tensorflow
@@ -574,7 +574,6 @@ class Conv2DTest(test.TestCase):
         padding="VALID")

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D0x0Padding(self):
     self._VerifyExplicitPaddings(
         tensor_in_sizes=[1, 2, 3, 3],
@@ -589,7 +588,6 @@ class Conv2DTest(test.TestCase):
         padding=[[0, 0], [0, 0]])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D1x1Padding(self):
     self._VerifyExplicitPaddings(
         tensor_in_sizes=[1, 2, 3, 2],
@@ -604,7 +602,6 @@ class Conv2DTest(test.TestCase):
         padding=[[1, 1], [1, 1]])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Padding(self):
     self._VerifyExplicitPaddings(
         tensor_in_sizes=[1, 2, 1, 2],
@@ -619,7 +616,6 @@ class Conv2DTest(test.TestCase):
         padding=[[2, 2], [2, 2]])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2DOnlyBottomPadding(self):
     self._VerifyExplicitPaddings(
         tensor_in_sizes=[1, 2, 3, 3],
@@ -634,7 +630,6 @@ class Conv2DTest(test.TestCase):
         padding=[[0, 3], [0, 0]])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2DOnlyTopRightPadding(self):
     self._VerifyExplicitPaddings(
         tensor_in_sizes=[1, 2, 3, 3],
@@ -650,7 +645,6 @@ class Conv2DTest(test.TestCase):
         padding=[[1, 0], [0, 2]])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2DLotsPadding(self):
     self._VerifyExplicitPaddings(
         tensor_in_sizes=[1, 1, 1, 3],
@@ -665,7 +659,6 @@ class Conv2DTest(test.TestCase):
         padding=[[3, 4], [4, 2]])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2DExplicitPaddingWithDilations(self):
     self._VerifyExplicitPaddings(
         tensor_in_sizes=[1, 3, 2, 1],
@@ -681,7 +674,6 @@ class Conv2DTest(test.TestCase):
         padding=[[2, 1], [1, 2]],
         dilations=[2, 3])

-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2DExplicitPaddingWithLayoutOptimizer(self):
     # Test with Grappler's layout optimizer, to ensure the layout optimizer
     # handles explicit padding correctly.
@@ -1349,7 +1341,6 @@ class Conv2DTest(test.TestCase):
         dilations=dilations)

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding0x0BackpropInput(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1372,7 +1363,6 @@ class Conv2DTest(test.TestCase):
         data_format=data_format)

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding1x1BackpropInput(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1405,7 +1395,6 @@ class Conv2DTest(test.TestCase):
         dilations=[2, 2])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding2x2BackpropInput(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1430,7 +1419,6 @@ class Conv2DTest(test.TestCase):
         dilations=[2, 3])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding_1_8_4_1_BackpropInput(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1453,7 +1441,6 @@ class Conv2DTest(test.TestCase):
         data_format=data_format)

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding_5_0_2_2_BackpropInput(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1512,7 +1499,6 @@ class Conv2DTest(test.TestCase):
         err=err)

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding0x0BackpropFilter(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1535,7 +1521,6 @@ class Conv2DTest(test.TestCase):
         data_format=data_format)

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding1x1BackpropFilter(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1569,7 +1554,6 @@ class Conv2DTest(test.TestCase):
         dilations=[2, 2])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding2x2BackpropFilter(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1594,7 +1578,6 @@ class Conv2DTest(test.TestCase):
         dilations=[2, 3])

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding_1_8_4_1_BackpropFilter(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1618,7 +1601,6 @@ class Conv2DTest(test.TestCase):
         data_format=data_format)

   @test_util.run_in_graph_and_eager_modes()
-  @test_util.disable_xla("This test never passed for XLA")
   def testConv2D2x2Depth1Padding_5_0_2_2_BackpropFilter(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1976,7 +1958,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testInputGradient1x1PaddingStrideOne(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -1998,7 +1979,6 @@ class Conv2DTest(test.TestCase):
           use_gpu=use_gpu,
           max_err=0.0025)

-  @test_util.disable_xla("This test never passed for XLA")
   def testFilterGradient1x1PaddingStrideOne(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2019,7 +1999,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testInputGradient1x1PaddingStrideTwo(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2040,7 +2019,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testFilterGradient1x1PaddingStrideTwo(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2061,7 +2039,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testInputGradient2x2PaddingStrideOne(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2082,7 +2059,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testFilterGradient2x2PaddingStrideOne(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2104,7 +2080,6 @@ class Conv2DTest(test.TestCase):
           use_gpu=use_gpu,
           max_err=0.003)

-  @test_util.disable_xla("This test never passed for XLA")
   def testInputGradient1_2_3_4PaddingStride3x2(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2125,7 +2100,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testFilterGradient1_2_3_4PaddingStride3x2(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2146,7 +2120,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testInputGradient4_3_2_1PaddingStride2x1(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2167,7 +2140,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testFilterGradient4_3_2_1PaddingStride2x1(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2188,7 +2160,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testInputGradient0_0_0_5PaddingStride1x2(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2209,7 +2180,6 @@ class Conv2DTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)

-  @test_util.disable_xla("This test never passed for XLA")
   def testFilterGradient0_0_0_5PaddingStride1x2(self):
     if not test.is_gpu_available(cuda_only=True):
       return
@@ -2316,7 +2286,7 @@ class Conv2DTest(test.TestCase):
         strides=[1, 1, 1, 1],
         padding=[0, 0, 0, 0])

-  @test_util.disable_xla("This test never passed for XLA")
+  @test_util.disable_xla("b/123337890")  # Error messages differ
   def testOpEdgeCases(self):
     with self.cached_session() as sess:
       # Illegal strides.