STT-tensorflow/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
Brian Zhao 556824565d Automated g4 rollback of changelist 304856650.
PiperOrigin-RevId: 305076580
Change-Id: I98886941dbfb25acd99d6ca2601eaee6dc657034
2020-04-06 11:29:58 -07:00

515 lines
22 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
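// XLA-specific helpers that lower TensorFlow convolutions (N-D, grouped and
// depthwise) and their input/filter gradients to XLA convolution ops.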
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
namespace {
// Returns the shape of a depthwise-convolution filter regrouped for use as the
// kernel of a grouped convolution.
// If `filter_shape` is [H, W, ..., M, N] this returns [H, W, ..., 1, M*N]: the
// input-feature dimension (second to last) becomes 1 and its size is folded
// into the output-feature dimension (last). All other dimensions, and any
// other properties of `filter_shape`, are copied unchanged.
// NOTE(review): assumes `filter_shape` has rank >= 2; callers are expected to
// have validated the rank.
xla::Shape GroupedFilterShapeForDepthwiseConvolution(
    const xla::Shape& filter_shape) {
  const int64 input_feature_dim = filter_shape.dimensions_size() - 2;
  const int64 output_feature_dim = filter_shape.dimensions_size() - 1;
  const int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
  const int64 input_feature = filter_shape.dimensions(input_feature_dim);
  // Create a [H, W, ..., 1, M*N] reshape of the filter.
  xla::Shape grouped_filter_shape = filter_shape;
  grouped_filter_shape.set_dimensions(input_feature_dim, 1);
  grouped_filter_shape.set_dimensions(output_feature_dim,
                                      depthwise_multiplier * input_feature);
  return grouped_filter_shape;
}
// Rearranges a grouped-convolution filter so that it can serve as the kernel of
// the input-gradient (BackpropInput) convolution.
//
// With `num_groups` = G and a filter of shape [H, W, ..., in, out] (the first
// `num_spatial_dims` dimensions are spatial), the result has shape
// [H, W, ..., G*in, out/G], built in three steps:
//   1. split `out` into [G, out/G],
//   2. move the G axis in front of `in`,
//   3. merge [G, in] into a single dimension.
// CHECK-fails if the filter has fewer than two dimensions.
xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput(
    xla::XlaOp filter, const xla::Shape& filter_shape, int64 num_groups,
    int num_spatial_dims) {
  const int rank = filter_shape.dimensions_size();
  CHECK_GE(rank, 2);  // Crash OK
  const int64 out_depth = filter_shape.dimensions(rank - 1);

  // Step 1: [H, W, ..., in, out] -> [H, W, ..., in, G, out/G].
  xla::Shape split_shape = filter_shape;
  split_shape.set_dimensions(rank - 1, num_groups);
  split_shape.add_dimensions(out_depth / num_groups);
  xla::XlaOp split = xla::Reshape(filter, split_shape.dimensions());

  // Step 2: exchange `in` and G -> [H, W, ..., G, in, out/G].
  std::vector<int64> permutation;
  permutation.reserve(rank + 1);
  for (int64 axis = 0; axis <= rank; ++axis) {
    permutation.push_back(axis);
  }
  std::swap(permutation[num_spatial_dims], permutation[num_spatial_dims + 1]);
  xla::XlaOp transposed = xla::Transpose(split, permutation);

  // Step 3: merge [G, in] -> [H, W, ..., G*in, out/G].
  return xla::Collapse(transposed, {num_spatial_dims, num_spatial_dims + 1});
}
// Rearranges the activations of a grouped convolution for use in the
// filter-gradient (BackpropFilter) convolution.
//
// The feature dimension at `depth_dim` (size C) is split into [G, C/G] with
// G = `num_groups`. The new G axis is then moved directly in front of the batch
// dimension at `batch_dim`, and [G, N] is merged into a single G*N dimension.
// For example, NHWC [N, H, W, C] becomes [G*N, H, W, C/G].
// NOTE(review): no caller of this helper is visible in this file; verify that it
// is still needed (an unused function in an anonymous namespace triggers
// -Wunused-function).
xla::XlaOp TransposeInputForGroupConvolutionBackpropFilter(
    xla::XlaOp input, const xla::Shape& input_shape, int64 num_groups,
    int batch_dim, int depth_dim) {
  const int rank = input_shape.dimensions_size();

  // Step 1: split the feature dimension C into [G, C/G].
  std::vector<int64> split_dims(input_shape.dimensions().begin(),
                                input_shape.dimensions().end());
  const int64 group_size = split_dims[depth_dim] / num_groups;
  split_dims[depth_dim] = group_size;
  split_dims.insert(split_dims.begin() + depth_dim, num_groups);
  xla::XlaOp split = xla::Reshape(input, split_dims);

  // Step 2: move the G axis (now at `depth_dim`) in front of the batch axis,
  // e.g. permutation [3, 0, 1, 2, 4] maps [N, H, W, G, C/G] to
  // [G, N, H, W, C/G].
  std::vector<int64> permutation(rank + 1);
  std::iota(permutation.begin(), permutation.end(), 0);
  permutation.erase(permutation.begin() + depth_dim);
  permutation.insert(permutation.begin() + batch_dim, depth_dim);
  xla::XlaOp transposed = xla::Transpose(split, permutation);

  // Step 3: merge [G, N] into G*N.
  return xla::Collapse(transposed, {batch_dim, batch_dim + 1});
}
// Prepares a depthwise filter for use as a grouped-convolution kernel by
// reshaping `filter`, whose shape is `filter_shape` = [H, W, ..., M, N], to
// [H, W, ..., 1, M*N], the shape produced by
// GroupedFilterShapeForDepthwiseConvolution.
xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
                                                xla::XlaOp filter) {
  const xla::Shape grouped_shape =
      GroupedFilterShapeForDepthwiseConvolution(filter_shape);
  return xla::Reshape(filter, grouped_shape.dimensions());
}
// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
// convolutions (as currently implemented):
//  - `strides` and `dilations` have exactly num_spatial_dims + 2 entries;
//  - the batch and feature dimensions have stride 1 and dilation 1;
//  - every spatial dilation is at least 1;
//  - with EXPLICIT padding, `explicit_paddings` holds one (before, after) pair
//    per tensor dimension, i.e. 2 * (num_spatial_dims + 2) entries.
// Returns InvalidArgument for attribute lists of the wrong length and
// Unimplemented for unsupported stride/dilation values.
Status CheckConvAttrs(const ConvOpAttrs& attrs) {
  const int num_dims = attrs.num_spatial_dims + 2;
  // Compare as int to avoid a signed/unsigned comparison with size().
  const int strides_size = attrs.strides.size();
  if (strides_size != num_dims) {
    return errors::InvalidArgument("Sliding window strides field must specify ",
                                   num_dims, " dimensions");
  }
  const int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  const int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
  if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
    return errors::Unimplemented(
        "Current implementation does not yet support strides in the batch and "
        "depth dimensions.");
  }
  const int dilations_size = attrs.dilations.size();
  if (dilations_size != num_dims) {
    return errors::InvalidArgument("Dilations field must specify ", num_dims,
                                   " dimensions");
  }
  if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
    return errors::Unimplemented(
        "Current implementation does not support dilations in the batch and "
        "depth dimensions.");
  }
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    const int input_dim =
        GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (attrs.dilations[input_dim] < 1) {
      return errors::Unimplemented("Dilation values must be positive; ", i,
                                   "th spatial dimension had dilation ",
                                   attrs.dilations[input_dim]);
    }
  }
  // The convolution builders read explicit_paddings[2 * dim] and
  // explicit_paddings[2 * dim + 1] for each spatial dimension `dim`, so a
  // malformed list must be rejected here rather than indexed out of range.
  if (attrs.padding == EXPLICIT) {
    const int explicit_paddings_size = attrs.explicit_paddings.size();
    if (explicit_paddings_size != 2 * num_dims) {
      return errors::InvalidArgument(
          "explicit_paddings attribute must contain ", 2 * num_dims,
          " values, but got: ", explicit_paddings_size);
    }
  }
  return Status::OK();
}
// Adapter over ConvBackpropComputeDimensionsV2 for callers that hold XLA
// shapes. It converts `input_shape`, `filter_shape` and `out_backprop_shape`
// to TensorShapes, in that order, and returns the first conversion error.
// All other arguments are forwarded unchanged. Note that the argument order
// differs from the wrapped function: here `explicit_paddings` comes last.
Status ConvBackpropComputeDimensionsV2XlaShapes(
    StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
    const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
    absl::Span<const int32> dilations, const std::vector<int32>& strides,
    Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
    absl::Span<const int64> explicit_paddings) {
  TensorShape input_ts;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_ts));

  TensorShape filter_ts;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_ts));

  TensorShape out_backprop_ts;
  TF_RETURN_IF_ERROR(
      XLAShapeToTensorShape(out_backprop_shape, &out_backprop_ts));

  return ConvBackpropComputeDimensionsV2(
      label, num_spatial_dims, input_ts, filter_ts, out_backprop_ts, dilations,
      strides, padding, explicit_paddings, data_format, dims);
}
} // anonymous namespace
// Returns the element types supported by the XLA convolution kernels.
//
// The span views a function-local static array, so it remains valid for the
// whole program. Returning a braced list directly would bind the span to a
// std::initializer_list whose backing array is destroyed when the return
// statement completes, leaving callers with a dangling span.
absl::Span<const DataType> GetXlaConvTypes() {
  static constexpr DataType kXlaConvTypes[] = {DT_FLOAT, DT_BFLOAT16, DT_HALF,
                                               DT_DOUBLE};
  return kXlaConvTypes;
}
// Reads the convolution attributes shared by the XLA convolution kernels from
// `ctx`. It reads "dilations", "strides" and "padding", then
// "explicit_paddings" (only when padding is EXPLICIT), then "data_format".
// `num_spatial_dims` and `depthwise` are stored as given.
// Returns the first attribute-lookup error, or InvalidArgument if
// "data_format" is not a recognized format string.
xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
                                               bool depthwise,
                                               OpKernelConstruction* ctx) {
  ConvOpAttrs result;
  result.num_spatial_dims = num_spatial_dims;
  result.depthwise = depthwise;

  TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &result.dilations));
  TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &result.strides));
  TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &result.padding));

  // "explicit_paddings" is only read for EXPLICIT padding; otherwise the
  // field keeps its default value.
  const bool has_explicit_padding = result.padding == EXPLICIT;
  if (has_explicit_padding) {
    TF_RETURN_IF_ERROR(
        ctx->GetAttr("explicit_paddings", &result.explicit_paddings));
  }

  string format_name;
  TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &format_name));
  const bool format_ok = FormatFromString(format_name, &result.data_format);
  if (!format_ok) {
    return errors::InvalidArgument("Invalid data format: ", format_name);
  }
  return result;
}
// Builds the XLA convolution for a TensorFlow forward convolution op.
//
// `conv_input` must have rank attrs.num_spatial_dims + 2, laid out according to
// attrs.data_format. `filter` must have the same rank and the layout
// [spatial dims..., filter_in_depth, out_depth].
//
// Depth requirements:
//  - the input depth must be a non-zero multiple of filter_in_depth (itself
//    non-zero);
//  - the quotient is the feature group count;
//  - out_depth must be a multiple of the feature group count.
//
// With attrs.depthwise, the filter is reshaped to
// [spatial dims..., 1, filter_in_depth * out_depth] and the full input depth is
// used as the feature group count.
//
// Returns InvalidArgument for malformed shapes, plus any error from
// CheckConvAttrs or GetWindowedOutputSizeVerboseV2.
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
    StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter,
    const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  auto* builder = conv_input.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
  // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));

  // For 2D convolution, there should be 4 dimensions.
  const int num_dims = attrs.num_spatial_dims + 2;
  if (input_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument("input must be ", num_dims,
                                   "-dimensional: ", input_shape.DebugString());
  }
  if (filter_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument(
        "filter must be ", num_dims,
        "-dimensional: ", filter_shape.DebugString());
  }

  const int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  const int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
  // The last two dimensions of the filter are the input and output depths.
  const int64 filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
  const int64 out_depth = filter_shape.dimensions(attrs.num_spatial_dims + 1);
  const int64 in_depth = input_shape.dimensions(feature_dim);

  // filter_in_depth is used as a divisor below; reject 0 instead of dividing
  // by zero.
  if (filter_in_depth == 0) {
    return errors::InvalidArgument("Depth of filter must not be 0: ",
                                   filter_shape.DebugString());
  }
  // The 'C' dimension for input is in_depth.
  // It must be a multiple of the filter's in_depth.
  if (in_depth % filter_in_depth != 0) {
    return errors::InvalidArgument(
        "Depth of input must be a multiple of depth of filter: ", in_depth,
        " vs ", filter_in_depth);
  }
  const int64 feature_group_count = in_depth / filter_in_depth;
  // After the divisibility check, a zero group count means the input depth is
  // 0. The group count is a divisor below, so reject that case as well.
  if (feature_group_count == 0) {
    return errors::InvalidArgument("Depth of input must not be 0: ",
                                   input_shape.DebugString());
  }
  if (out_depth % feature_group_count != 0) {
    return errors::InvalidArgument(
        "Depth of output must be a multiple of the number of groups: ",
        out_depth, " vs ", feature_group_count);
  }

  if (attrs.depthwise) {
    filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
  }

  xla::ConvolutionDimensionNumbers dims;
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);

  dims.set_input_batch_dimension(batch_dim);
  dims.set_output_batch_dimension(batch_dim);
  dims.set_input_feature_dimension(feature_dim);
  dims.set_output_feature_dimension(feature_dim);
  dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
  dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);

  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dims.add_input_spatial_dimensions(dim);
    dims.add_kernel_spatial_dimensions(i);
    dims.add_output_spatial_dimensions(dim);
    window_strides[i] = attrs.strides.at(dim);
    rhs_dilation[i] = attrs.dilations.at(dim);
    if (attrs.padding == EXPLICIT) {
      // explicit_paddings holds a (before, after) pair per tensor dimension.
      padding[i] = {attrs.explicit_paddings.at(dim * 2),
                    attrs.explicit_paddings.at(dim * 2 + 1)};
    }
    // padding[i] is passed by pointer: it is pre-filled above for EXPLICIT
    // padding and presumably computed by the callee for SAME/VALID.
    int64 unused_output_size;
    TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
        input_shape.dimensions(dim), filter_shape.dimensions(i),
        rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
        &padding[i].first, &padding[i].second));
  }

  return xla::ConvGeneralDilated(
      conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
      dims,
      /*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count,
      /*batch_group_count=*/1, precision_config);
}
// Builds the XLA convolution computing the gradient of a TensorFlow
// convolution with respect to its input.
//
// Arguments:
//  - `input_shape`: shape of the original convolution input;
//  - `filter`: the original filter, laid out [spatial dims..., in, out];
//  - `out_backprop`: the gradient flowing back from the convolution output.
//
// The result convolves `out_backprop` with the spatially mirrored filter. The
// forward strides become LHS dilation, and the padding comes from
// ConvBackpropComputeDimensionsV2. Grouped (non-depthwise) filters are first
// regrouped by TransposeFilterForGroupConvolutionBackpropInput.
//
// Returns InvalidArgument for malformed shapes or attributes.
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
    StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
    xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
    const xla::PrecisionConfig* precision_config) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
  const int num_dims = attrs.num_spatial_dims + 2;
  const int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  const int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
  auto* builder = filter.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(out_backprop));

  // The depths below are read before ConvBackpropComputeDimensionsV2 runs, so
  // validate the ranks here to avoid indexing past the end of a shape.
  if (input_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument(type_string, ": input must be ", num_dims,
                                   "-dimensional: ", input_shape.DebugString());
  }
  if (filter_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument(type_string, ": filter must be ", num_dims,
                                   "-dimensional: ",
                                   filter_shape.DebugString());
  }

  const int64 in_depth = input_shape.dimensions(feature_dim);
  const int64 filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
  // filter_in_depth is used as a divisor; reject 0 instead of dividing by zero.
  if (filter_in_depth == 0) {
    return errors::InvalidArgument(type_string,
                                   ": filter input depth must not be 0: ",
                                   filter_shape.DebugString());
  }
  const int64 feature_group_count =
      attrs.depthwise ? filter_in_depth : in_depth / filter_in_depth;

  const xla::Shape grouped_filter_shape =
      attrs.depthwise ? GroupedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_shape_utils.cc.
  ConvBackpropDimensions dims;
  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, input_shape, grouped_filter_shape,
      out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
      attrs.data_format, &dims, attrs.explicit_paddings));

  // A zero group count (e.g. from an input of depth 0) would make
  // TransposeFilterForGroupConvolutionBackpropInput divide by zero.
  if (feature_group_count == 0) {
    return errors::InvalidArgument(type_string,
                                   ": input depth must not be 0: ",
                                   input_shape.DebugString());
  }

  // The input gradients are computed by a convolution of the output
  // gradients and the filter, with some appropriate padding. See the
  // comment at the top of conv_grad_shape_utils.h for details.
  xla::ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(batch_dim);
  dnums.set_output_batch_dimension(batch_dim);
  dnums.set_input_feature_dimension(feature_dim);
  dnums.set_output_feature_dimension(feature_dim);

  // TF filter shape is [ H, W, ..., inC, outC ]
  // Transpose the input and output features for computing the gradient.
  dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
  dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);

  std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(i);
    dnums.add_output_spatial_dimensions(dim);

    kernel_spatial_dims[i] = i;
    padding[i] = {dims.spatial_dims[i].pad_before,
                  dims.spatial_dims[i].pad_after};
    lhs_dilation[i] = dims.spatial_dims[i].stride;
    rhs_dilation[i] = attrs.dilations[dim];
  }

  if (feature_group_count != 1 && !attrs.depthwise) {
    filter = TransposeFilterForGroupConvolutionBackpropInput(
        filter, filter_shape, feature_group_count, attrs.num_spatial_dims);
  }
  // Mirror the filter in the spatial dimensions.
  filter = xla::Rev(filter, kernel_spatial_dims);

  // activation gradients
  //   = gradients (with padding and dilation) <conv> mirrored_weights
  return xla::ConvGeneralDilated(out_backprop, filter, /*window_strides=*/ones,
                                 padding, lhs_dilation, rhs_dilation, dnums,
                                 /*feature_group_count=*/
                                 feature_group_count,
                                 /*batch_group_count=*/1, precision_config);
}
// Builds the XLA convolution computing the gradient of a TensorFlow
// convolution with respect to its filter.
//
// Arguments:
//  - `activations`: the original convolution input;
//  - `filter_shape`: the filter's shape, [spatial dims..., in, out];
//  - `gradients`: the gradient flowing back from the convolution output.
//
// How the gradient convolution is set up:
//  - the activations' batch and feature dimensions are swapped, so the
//    convolution sums over the batch;
//  - the gradients act as the kernel, dilated by the forward strides;
//  - the forward dilations become the window strides.
//
// Grouped and depthwise filters are handled through batch_group_count, and
// depthwise results are reshaped back to `filter_shape`.
//
// Returns InvalidArgument for malformed shapes or attributes.
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
    StringPiece type_string, xla::XlaOp activations,
    const xla::Shape& filter_shape, xla::XlaOp gradients,
    const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
  auto* builder = activations.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
                      builder->GetShape(activations));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(gradients));

  const int num_dims = attrs.num_spatial_dims + 2;
  // The filter's input depth is read, and a depthwise filter is regrouped,
  // before ConvBackpropComputeDimensionsV2 validates ranks. Check the filter
  // rank first so these reads stay in bounds.
  if (filter_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument(type_string, ": filter must be ", num_dims,
                                   "-dimensional: ",
                                   filter_shape.DebugString());
  }
  // The last two dimensions of the filter are the input and output depths.
  const int64 filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
  // filter_in_depth is used as a divisor (here and in the dimension helper);
  // reject 0 instead of dividing by zero.
  if (filter_in_depth == 0) {
    return errors::InvalidArgument(type_string,
                                   ": filter input depth must not be 0: ",
                                   filter_shape.DebugString());
  }

  const xla::Shape grouped_filter_shape =
      attrs.depthwise ? GroupedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_shape_utils.cc. The
  // helper also converts all three shapes to TensorShapes, which rejects
  // non-array shapes.
  ConvBackpropDimensions dims;
  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, activations_shape,
      grouped_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
      attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));

  // Obtain some useful dimensions.
  const int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  const int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
  const int64 in_depth = activations_shape.dimensions(c_dim);
  const int64 batch_group_count =
      attrs.depthwise ? filter_in_depth : in_depth / filter_in_depth;

  // The filter gradients are computed by a convolution of the input
  // activations and the output gradients, with some appropriate padding.
  // See the comment at the top of conv_grad_shape_utils.h for details.
  xla::ConvolutionDimensionNumbers dnums;
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);

  // Swap n_dim and c_dim in the activations.
  dnums.set_input_batch_dimension(c_dim);
  dnums.set_input_feature_dimension(n_dim);

  // The gradients become the RHS of the convolution.
  // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
  // where the batch becomes the input feature for the convolution.
  dnums.set_kernel_input_feature_dimension(n_dim);
  dnums.set_kernel_output_feature_dimension(c_dim);

  // Tensorflow filter shape is [ H, W, ..., inC, outC ].
  dnums.set_output_batch_dimension(attrs.num_spatial_dims);
  dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);

  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dnums.add_output_spatial_dimensions(i);
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(dim);
    rhs_dilation[i] = dims.spatial_dims[i].stride;
    window_strides[i] = attrs.dilations[dim];

    // We will also need to pad the input with zeros such that after the
    // convolution, we get the right size for the filter.
    // The padded_in_rows should be such that when we convolve this with the
    // expanded_out_rows as a filter, we should get filter_rows back.
    const int64 padded_in_size =
        dims.spatial_dims[i].expanded_output_size +
        (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];

    // However it can be smaller than input_rows: in this
    // case it means some of the inputs are not used.
    //
    // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
    //
    // INPUT =  [ A  B  C ]
    //
    // FILTER = [ x y ]
    //
    // and the output will only have one column: a = A * x + B * y
    //
    // and input "C" is not used at all.
    //
    // We apply negative padding in this case.
    const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;

    // + For the EXPLICIT padding, we pad the top/left side with the explicit
    //   padding and pad the bottom/right side with the remaining space.
    // + For the VALID padding, we don't pad anything on the top/left side
    //   and pad the bottom/right side with the remaining space.
    // + For the SAME padding, we pad top/left side the same as bottom/right
    //   side.
    //
    // In addition, if the padded input size is smaller than the input size,
    // we need to ignore some trailing elements of the input. We do this by
    // applying negative padding on the right/bottom.
    const int64 pad_before = attrs.padding == Padding::EXPLICIT
                                 ? attrs.explicit_paddings.at(2 * dim)
                                 : attrs.padding == Padding::SAME
                                       ? std::max<int64>(pad_total / 2, 0)
                                       : 0;
    padding[i] = {pad_before, pad_total - pad_before};
  }

  // Besides padding the input, we will also expand output_rows to
  //    expanded_out_rows = (output_rows - 1) * stride + 1
  // with zeros in between:
  //
  //      a . . . b . . . c . . . d . . . e
  //
  // This is done by specifying the window dilation factors in the
  // convolution HLO below.
  xla::XlaOp filter_backprop = xla::ConvGeneralDilated(
      activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
      rhs_dilation, dnums,
      /*feature_group_count=*/1,
      /*batch_group_count=*/batch_group_count, precision_config);

  if (attrs.depthwise) {
    // Restore the caller's [H, W, ..., M, N] depthwise filter layout.
    filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions());
  }

  return filter_backprop;
}
} // namespace tensorflow