Support dynamic spatial dimensions in convolution.

- Lower dynamic convolutions into kCustomCalls.
- Transform those kCustomCalls into static convolutions in the dynamic padder.
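
Schematically (a hand-written sketch for illustration, not output captured from this change; shapes, instruction names, and attribute order are approximate), a convolution with a dynamic spatial dimension is first emitted by the tf2xla bridge as a custom call, which the dynamic padder later rewrites into an ordinary static convolution:

Before the dynamic padder (bridge output):
  %conv = f32[1,<=8,1] custom-call(f32[1,<=8,1] %input, f32[3,1,1] %filter),
            window={size=3 pad=1_1}, dim_labels=b0f_0io->b0f,
            custom_call_target="DynamicConvolutionForward", padding_type=PADDING_SAME

After the dynamic padder (the dynamic SAME padding is materialized with pad/dynamic-slice and the custom call becomes a regular convolution):
  %padded = f32[1,10,1] dynamic-slice(%pad, %zero, %start, %zero), dynamic_slice_sizes={1,10,1}
  %conv = f32[1,<=8,1] convolution(%padded, %filter), window={size=3}, dim_labels=b0f_0io->b0f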

PiperOrigin-RevId: 339275495
Change-Id: I0e1a6c0ff7f539e63482f1de7d564dca23ab81bc
Yunxing Dai 2020-10-27 10:19:30 -07:00 committed by TensorFlower Gardener
parent 9c31fe8ef6
commit 292fa95c75
21 changed files with 1083 additions and 30 deletions

View File

@ -269,9 +269,19 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
dims.set_output_feature_dimension(feature_dim);
dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
for (int i = 0; i < attrs.num_spatial_dims; ++i) {
const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
if (input_shape.is_dynamic_dimension(dim)) {
TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
<< "Dynamic convolution only supports valid and same padding";
if (attrs.padding == VALID) {
padding_type = xla::PaddingType::PADDING_VALID;
}
if (attrs.padding == SAME) {
padding_type = xla::PaddingType::PADDING_SAME;
}
}
dims.add_input_spatial_dimensions(dim);
dims.add_kernel_spatial_dimensions(i);
dims.add_output_spatial_dimensions(dim);
@ -290,6 +300,15 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
&padding[i].first, &padding[i].second));
}
if (padding_type != xla::PaddingType::PADDING_INVALID) {
return xla::DynamicConvForward(
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
dims,
/*feature_group_count=*/attrs.depthwise ? in_depth
: feature_group_count,
/*batch_group_count=*/1, precision_config, padding_type);
}
return xla::ConvGeneralDilated(
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
dims,
@ -300,7 +319,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config) {
const xla::PrecisionConfig* precision_config, xla::XlaOp* input_sizes) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
int num_dims = attrs.num_spatial_dims + 2;
@ -347,8 +366,19 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
std::vector<int64> ones(attrs.num_spatial_dims, 1);
xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
for (int i = 0; i < attrs.num_spatial_dims; ++i) {
int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
if (out_backprop_shape.is_dynamic_dimension(dim)) {
TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
<< "Dynamic convolution only supports valid and same padding";
if (attrs.padding == VALID) {
padding_type = xla::PaddingType::PADDING_VALID;
}
if (attrs.padding == SAME) {
padding_type = xla::PaddingType::PADDING_SAME;
}
}
dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(i);
dnums.add_output_spatial_dimensions(dim);
@ -366,7 +396,15 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
}
// Mirror the filter in the spatial dimensions.
filter = xla::Rev(filter, kernel_spatial_dims);
if (padding_type != xla::PaddingType::PADDING_INVALID) {
TF_RET_CHECK(input_sizes != nullptr);
return xla::DynamicConvInputGrad(
*input_sizes, out_backprop, filter, /*window_strides=*/ones, padding,
lhs_dilation, rhs_dilation, dnums,
/*feature_group_count=*/
feature_group_count,
/*batch_group_count=*/1, precision_config, padding_type);
}
// activation gradients
// = gradients (with padding and dilation) <conv> mirrored_weights
return xla::ConvGeneralDilated(out_backprop, filter, /*window_strides=*/ones,
@ -444,9 +482,19 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
for (int i = 0; i < attrs.num_spatial_dims; ++i) {
dnums.add_output_spatial_dimensions(i);
}
xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
if (activations_shape.is_dynamic_dimension(dim)) {
TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
<< "Dynamic convolution only supports valid and same padding";
if (attrs.padding == VALID) {
padding_type = xla::PaddingType::PADDING_VALID;
}
if (attrs.padding == SAME) {
padding_type = xla::PaddingType::PADDING_SAME;
}
}
dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(dim);
rhs_dilation[i] = dims.spatial_dims[i].stride;
@ -503,12 +551,20 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
//
// This is done by specifying the window dilation factors in the
// convolution HLO below.
filter_backprop = xla::ConvGeneralDilated(
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
rhs_dilation, dnums,
/*feature_group_count=*/1,
/*batch_group_count=*/batch_group_count, precision_config);
if (padding_type != xla::PaddingType::PADDING_INVALID) {
filter_backprop = xla::DynamicConvKernelGrad(
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
rhs_dilation, dnums,
/*feature_group_count=*/1,
/*batch_group_count=*/batch_group_count, precision_config,
padding_type);
} else {
filter_backprop = xla::ConvGeneralDilated(
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
rhs_dilation, dnums,
/*feature_group_count=*/1,
/*batch_group_count=*/batch_group_count, precision_config);
}
if (attrs.depthwise) {
filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions());

View File

@ -64,7 +64,8 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config = nullptr);
const xla::PrecisionConfig* precision_config = nullptr,
xla::XlaOp* input_sizes = nullptr);
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
StringPiece type_string, xla::XlaOp activations,
const xla::Shape& filter_shape, xla::XlaOp gradients,

View File

@ -112,10 +112,10 @@ class ConvBackpropInputOp : public XlaOpKernel {
"The rank of the specified input shape must be "
"num_spatial_dims + 2. Expected ",
attrs_.num_spatial_dims + 2, " got ", input_shape.rank()));
xla::StatusOr<xla::XlaOp> in_backprop =
MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
ctx->Input(1), ctx->Input(2), attrs_);
xla::XlaOp input_sizes = ctx->Input(0);
xla::StatusOr<xla::XlaOp> in_backprop = MakeXlaBackpropInputConvOp(
ctx->op_kernel().type_string(), input_shape, ctx->Input(1),
ctx->Input(2), attrs_, nullptr, &input_sizes);
OP_REQUIRES_OK(ctx, in_backprop.status());
ctx->SetOutput(0, in_backprop.ValueOrDie());
}

View File

@ -307,7 +307,8 @@ REGISTER_OP("XlaSetDynamicDimensionSize")
.Input("size: int32")
.Output("output: T")
.Attr("T: type")
.SetShapeFn(shape_inference::UnchangedShape)
// Use unknown shape to prevent constant folding.
.SetShapeFn(shape_inference::UnknownShape)
.Doc(
R"doc(Make a static dimension into a xla bounded dynamic dimension.
The current static dimension size will become the bound and the second

View File

@ -1439,6 +1439,110 @@ XlaOp XlaBuilder::ConvGeneralDilated(
});
}
StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
std::vector<int64> window_dimensions(
dimension_numbers.kernel_spatial_dimensions_size());
for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
window_dimensions[i] =
rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
}
TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides,
padding, lhs_dilation, rhs_dilation));
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferConvolveShape(
*lhs_shape, *rhs_shape, feature_group_count,
batch_group_count, window, dimension_numbers));
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_window() = window;
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
instr.set_feature_group_count(feature_group_count);
instr.set_batch_group_count(batch_group_count);
instr.set_padding_type(padding_type);
if (precision_config != nullptr) {
*instr.mutable_precision_config() = *precision_config;
}
return std::move(instr);
}
XlaOp XlaBuilder::DynamicConvInputGrad(
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(
HloInstructionProto instr,
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision_config, padding_type));
instr.set_custom_call_target("DynamicConvolutionInputGrad");
return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
{input_sizes, lhs, rhs});
});
}
XlaOp XlaBuilder::DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(
HloInstructionProto instr,
DynamicConvInstruction(activations, gradients, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision_config, padding_type));
instr.set_custom_call_target("DynamicConvolutionKernelGrad");
return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
{activations, gradients});
});
}
XlaOp XlaBuilder::DynamicConvForward(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(
HloInstructionProto instr,
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision_config, padding_type));
instr.set_custom_call_target("DynamicConvolutionForward");
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
});
}
StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
absl::Span<const int64> window_strides,
@ -3901,6 +4005,49 @@ XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
precision_config);
}
XlaOp DynamicConvInputGrad(XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type) {
return lhs.builder()->DynamicConvInputGrad(
input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type);
}
XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
return activations.builder()->DynamicConvKernelGrad(
activations, gradients, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type);
}
XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type) {
return lhs.builder()->DynamicConvForward(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type);
}
XlaOp Fft(const XlaOp operand, FftType fft_type,
absl::Span<const int64> fft_length) {
return operand.builder()->Fft(operand, fft_type, fft_length);

View File

@ -560,6 +560,45 @@ class XlaBuilder {
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type);
XlaOp DynamicConvInputGrad(
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
StatusOr<HloInstructionProto> DynamicConvInstruction(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
absl::Span<const int64> window_strides,
@ -1057,13 +1096,49 @@ class XlaBuilder {
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config);
const PrecisionConfig* precision_config);
friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config);
friend XlaOp DynamicConvForward(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
friend XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
friend XlaOp DynamicConvInputGrad(
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
friend XlaOp ConvKernelGrad(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config);
friend XlaOp ConvGeneralDilated(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
@ -1761,6 +1836,34 @@ XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type);
XlaOp DynamicConvInputGrad(XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type);
XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span<const int64> fft_length);

View File

@ -2721,17 +2721,35 @@ cc_library(
],
)
cc_library(
name = "dynamic_window_utils",
srcs = ["dynamic_window_utils.cc"],
hdrs = ["dynamic_window_utils.h"],
deps = [
":hlo",
":shape_inference",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/core/platform:macros",
],
)
cc_library(
name = "dynamic_dimension_inference",
srcs = ["dynamic_dimension_inference.cc"],
hdrs = ["dynamic_dimension_inference.h"],
deps = [
":dynamic_window_utils",
":hlo",
":hlo_casting_utils",
":tuple_util",
":while_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -2751,10 +2769,13 @@ cc_library(
hdrs = ["dynamic_padder.h"],
deps = [
":dynamic_dimension_inference",
":dynamic_window_utils",
":hlo",
":hlo_casting_utils",
":hlo_creation_utils",
":hlo_dce",
":hlo_pass",
":hlo_verifier",
":shape_inference",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal",
@ -2762,6 +2783,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -29,10 +30,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
namespace xla {
namespace {
@ -157,6 +158,18 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
using DynamicDimensionFn = std::function<Status(
ShapeIndex index, int64 dimension, HloInstruction* dynamic_size)>;
Status HandleDynamicConvolutionForward(HloInstruction* hlo,
int64 operand_index, int64 dimension,
HloInstruction* dynamic_size);
Status HandleDynamicConvolutionKernelGrad(HloInstruction* hlo,
int64 operand_index,
int64 dimension);
Status HandleDynamicConvolutionInputGrad(HloInstruction* hlo,
int64 operand_index,
int64 dimension);
Status ForEachOperandDynamicDimension(HloInstruction* inst,
const OperandDynamicDimensionFn&);
Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
@ -256,6 +269,20 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
return Status::OK();
}
if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
return HandleDynamicConvolutionInputGrad(hlo, operand_index,
dimension);
}
if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
dimension);
}
if (hlo->custom_call_target() == "DynamicConvolutionForward") {
return HandleDynamicConvolutionForward(hlo, operand_index, dimension,
dynamic_size);
}
return Unimplemented(
"CustomCall \"%s\" is not supported to have a dynamic dimension",
hlo->custom_call_target());
@ -591,6 +618,70 @@ Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
return Status::OK();
}
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
HloInstruction* hlo, int64 operand_index, int64 dimension,
HloInstruction* dynamic_size) {
TF_RET_CHECK(operand_index == 0);
const ConvolutionDimensionNumbers& dimension_numbers =
hlo->convolution_dimension_numbers();
if (dimension == dimension_numbers.input_batch_dimension()) {
// Batch dimension is propagated without any changes.
parent_->SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
dynamic_size);
return Status::OK();
}
for (int64 spatial_dim_index = 0;
spatial_dim_index < dimension_numbers.input_spatial_dimensions_size();
++spatial_dim_index) {
int64 input_spatial_dim =
dimension_numbers.input_spatial_dimensions(spatial_dim_index);
int64 output_spatial_dim =
dimension_numbers.output_spatial_dimensions(spatial_dim_index);
if (dimension == input_spatial_dim) {
// This is a dynamic spatial dimension. Calculate the output size.
WindowDimension window_dim = hlo->window().dimensions(spatial_dim_index);
DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
dynamic_size, window_dim.size(), window_dim.window_dilation(),
window_dim.stride(), hlo->padding_type());
TF_RET_CHECK(window_dim.base_dilation() == 1);
parent_->SetDynamicSize(hlo, {}, output_spatial_dim,
dynamic_window_dims.output_size);
return Status::OK();
}
}
return Unimplemented(
"XLA doesn't support dynamic input feature dimension on convolution: %s",
hlo->ToString());
}
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionInputGrad(
HloInstruction* hlo, int64 operand_index, int64 dimension) {
// The output size of convolution input grad is the corresponding input size.
HloInstruction* input_sizes = hlo->mutable_operand(0);
HloComputation* comp = hlo->parent();
TF_RET_CHECK(input_sizes->shape().rank() == 1) << hlo->ToString();
TF_RET_CHECK(input_sizes->shape().element_type() == S32) << hlo->ToString();
TF_RET_CHECK(input_sizes->shape().dimensions(0) ==
hlo->shape().dimensions_size())
<< hlo->ToString();
// Slice to get the corresponding input size.
HloInstruction* slice = comp->AddInstruction(
HloInstruction::CreateSlice(ShapeUtil::MakeShape(S32, {1}), input_sizes,
{dimension}, {dimension + 1}, {1}));
HloInstruction* reshape = comp->AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
parent_->SetDynamicSize(hlo, {}, dimension, reshape);
return Status::OK();
}
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionKernelGrad(
HloInstruction* hlo, int64 operand_index, int64 dimension) {
// Dynamic convolution kernel grad produces static shape outputs.
return Status::OK();
}
Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
@ -1540,4 +1631,15 @@ HloInstruction* DynamicDimensionInference::GetDynamicSize(
return nullptr;
}
std::vector<HloInstruction*> DynamicDimensionInference::GetDynamicSizes(
HloInstruction* inst, const ShapeIndex& index) const {
CHECK(ShapeUtil::IndexIsValid(inst->shape(), index));
const int64 rank = ShapeUtil::GetSubshape(inst->shape(), index).rank();
std::vector<HloInstruction*> result(rank, nullptr);
for (int64 i = 0; i < rank; ++i) {
result[i] = GetDynamicSize(inst, {}, i);
}
return result;
}
} // namespace xla

View File

@ -51,8 +51,13 @@ class DynamicDimensionInference {
HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
int64 dim) const;
// Returns if current instruction contains any dynamic dimension. Recursively
// go into tuples.
// Returns dynamic sizes of all dimensions of `inst`'s leaf node at `index`.
// Static sizes are represented by nullptr.
std::vector<HloInstruction*> GetDynamicSizes(HloInstruction* inst,
const ShapeIndex& index) const;
// Returns if current instruction contains any dynamic dimension.
// Recursively go into tuples.
bool HasDynamicDimension(HloInstruction* inst) const;
// Forward dynamic dimension size at `dim` from `inst` to `new_inst`.

View File

@ -27,8 +27,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
@ -37,6 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
@ -122,9 +125,9 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
case HloOpcode::kSlice:
case HloOpcode::kDomain:
return nullptr;
// Assume that custom calls created by the client are valid with padded
// dynamic dimensions.
case HloOpcode::kCustomCall:
// Assume that custom calls created by the client are valid with padded
// dynamic dimensions.
return nullptr;
default:
return UnimplementedStrCat("Unimplemented padding for instruction: ",
@ -721,6 +724,262 @@ Status RewriteDynamicReshapeSingleGroup(
return Status::OK();
}
HloInstruction* RewriteInputWithDynamicPadding(
HloInstruction* conv, HloInstruction* input,
absl::Span<HloInstruction*> padding_before, Window* input_window) {
HloComputation* comp = conv->parent();
auto dnums = conv->convolution_dimension_numbers();
HloInstruction* zero_s32 = comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
Shape padded_grad_shape = input->shape();
PaddingConfig padding_configs;
for (int64 i = 0; i < input->shape().rank(); ++i) {
PaddingConfig::PaddingConfigDimension padding_dim;
*padding_configs.add_dimensions() = padding_dim;
}
std::vector<HloInstruction*> start_indices(input->shape().rank(), zero_s32);
for (int64 spatial_dim_index = 0;
spatial_dim_index < dnums.input_spatial_dimensions_size();
++spatial_dim_index) {
int64 input_spatial_dim = dnums.input_spatial_dimensions(spatial_dim_index);
if (padding_before[spatial_dim_index] == nullptr) {
continue;
}
WindowDimension* window_dim =
input_window->mutable_dimensions(spatial_dim_index);
auto* padding_dim = padding_configs.mutable_dimensions(input_spatial_dim);
const int64 dilated_window_size = window_util::DilatedBound(
window_dim->size(), window_dim->window_dilation());
// Choose the dilated window size as low padding and the static padding_high +
// padding_low as high padding to make sure the following dynamic slice is
// valid.
//
// See go/xla-dynamic-spatial-dim for more details.
padding_dim->set_edge_padding_low(dilated_window_size);
padding_dim->set_edge_padding_high(window_dim->padding_high() +
window_dim->padding_low());
padding_dim->set_interior_padding(window_dim->base_dilation() - 1);
HloInstruction* slicing_start =
comp->AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract,
comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(padding_dim->edge_padding_low()))),
padding_before[spatial_dim_index]));
start_indices[input_spatial_dim] = slicing_start;
padded_grad_shape.mutable_dimensions()[input_spatial_dim] =
window_dim->padding_low() +
window_util::DilatedBound(
padded_grad_shape.dimensions(input_spatial_dim),
window_dim->base_dilation()) +
window_dim->padding_high();
window_dim->clear_padding_high();
window_dim->clear_padding_low();
window_dim->set_base_dilation(1);
input->mutable_shape()->set_dynamic_dimension(input_spatial_dim, false);
}
// Reconstruct dynamic padding using pad and dynamic slice.
HloInstruction* pad =
MakePadHlo(input,
comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(conv->shape().element_type()))),
padding_configs)
.ValueOrDie();
input = comp->AddInstruction(HloInstruction::CreateDynamicSlice(
padded_grad_shape, pad, start_indices, padded_grad_shape.dimensions()));
return input;
}
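
As a hand-worked example of the scheme above (numbers chosen for illustration, not taken from the tests): with window size 3, stride 1, no dilation, a static bound of 8, and a dynamic size of 5, SAME padding gives padding_before = 1. The input, already zero-padded past its dynamic size, is padded low by the dilated window size (3) and high by the static padding_low + padding_high (2), then dynamically sliced starting at 3 - 1 = 2 with size padding_low + 8 + padding_high = 10. The slice holds exactly one zero before the data and zeros after it, so the following convolution, with its window padding cleared, sees the correct SAME padding for the dynamic size.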
StatusOr<bool> RewriteDynamicConvolutionInputGrad(
HloInstruction* custom_call_conv,
DynamicDimensionInference* dynamic_dimension_inference) {
HloInstruction* grad = custom_call_conv->mutable_operand(1);
HloInstruction* kernel = custom_call_conv->mutable_operand(2);
TF_RET_CHECK(kernel->shape().is_static());
auto dnums = custom_call_conv->convolution_dimension_numbers();
HloComputation* comp = custom_call_conv->parent();
Window window = custom_call_conv->window();
HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(custom_call_conv->shape().element_type())));
std::vector<HloInstruction*> padding_before(
dnums.input_spatial_dimensions_size(), nullptr);
for (int64 spatial_dim_index = 0;
spatial_dim_index < dnums.input_spatial_dimensions_size();
++spatial_dim_index) {
int64 input_spatial_dim = dnums.input_spatial_dimensions(spatial_dim_index);
HloInstruction* operand_dynamic_size =
dynamic_dimension_inference->GetDynamicSize(
custom_call_conv->mutable_operand(1), {}, input_spatial_dim);
if (operand_dynamic_size == nullptr) {
continue;
}
grad = PadWithScalar(grad, input_spatial_dim, operand_dynamic_size, zero);
HloInstruction* slice = comp->AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(S32, {1}), custom_call_conv->mutable_operand(0),
{input_spatial_dim}, {input_spatial_dim + 1}, {1}));
HloInstruction* dynamic_input_size = comp->AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
// The window stride of the forward pass equals the base dilation of the
// backward pass.
DynamicWindowDims dynamic_window_dims = GetWindowedInputGradSize(
dynamic_input_size, /*window_size=*/window_dim.size(),
/*window_dilation=*/window_dim.window_dilation(),
/*window_stride=*/window_dim.base_dilation(),
custom_call_conv->padding_type());
padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
}
if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
grad = RewriteInputWithDynamicPadding(
custom_call_conv, grad, absl::MakeSpan(padding_before), &window);
}
PrecisionConfig precision_config;
if (custom_call_conv->precision_config().operand_precision_size() == 3) {
// We are not interested in the precision config of the first operand, which
// is the input_sizes.
*precision_config.mutable_operand_precision() = {
custom_call_conv->precision_config().operand_precision().begin() + 1,
custom_call_conv->precision_config().operand_precision().end()};
}
HloInstruction* static_conv = comp->AddInstruction(
HloInstruction::CreateConvolve(
custom_call_conv->shape(), grad, kernel,
custom_call_conv->feature_group_count(),
custom_call_conv->batch_group_count(), window,
custom_call_conv->convolution_dimension_numbers(),
custom_call_conv->precision_config()),
"ConvBackwardInput");
TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
custom_call_conv, static_conv, {}));
return true;
}
StatusOr<bool> RewriteDynamicConvolutionForward(
HloInstruction* custom_call_conv,
DynamicDimensionInference* dynamic_dimension_inference) {
HloInstruction* input = custom_call_conv->mutable_operand(0);
HloInstruction* kernel = custom_call_conv->mutable_operand(1);
TF_RET_CHECK(kernel->shape().is_static());
TF_RET_CHECK(input->shape().is_dynamic());
HloComputation* comp = custom_call_conv->parent();
Window window = custom_call_conv->window();
auto dnums = custom_call_conv->convolution_dimension_numbers();
HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(custom_call_conv->shape().element_type())));
std::vector<HloInstruction*> padding_before(
dnums.input_spatial_dimensions_size(), nullptr);
for (int64 spatial_dim_index = 0;
spatial_dim_index < dnums.input_spatial_dimensions_size();
++spatial_dim_index) {
int64 input_spatial_dim = dnums.input_spatial_dimensions(spatial_dim_index);
HloInstruction* operand_dynamic_size =
dynamic_dimension_inference->GetDynamicSize(
custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
if (operand_dynamic_size == nullptr) {
continue;
}
input = PadWithScalar(input, input_spatial_dim, operand_dynamic_size, zero);
const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
window_dim.stride(), custom_call_conv->padding_type());
padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
}
if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
input = RewriteInputWithDynamicPadding(
custom_call_conv, input, absl::MakeSpan(padding_before), &window);
}
HloInstruction* static_conv = comp->AddInstruction(
HloInstruction::CreateConvolve(
custom_call_conv->shape(), input, kernel,
custom_call_conv->feature_group_count(),
custom_call_conv->batch_group_count(), window,
custom_call_conv->convolution_dimension_numbers(),
custom_call_conv->precision_config()),
"ConvForward");
TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
custom_call_conv, static_conv, {}));
return true;
}
StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
HloInstruction* custom_call_conv,
DynamicDimensionInference* dynamic_dimension_inference) {
HloInstruction* activations = custom_call_conv->mutable_operand(0);
HloInstruction* gradients = custom_call_conv->mutable_operand(1);
TF_RET_CHECK(activations->shape().is_dynamic());
TF_RET_CHECK(gradients->shape().is_dynamic());
HloComputation* comp = custom_call_conv->parent();
Window window = custom_call_conv->window();
auto dnums = custom_call_conv->convolution_dimension_numbers();
HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(custom_call_conv->shape().element_type())));
std::vector<HloInstruction*> padding_before(
dnums.input_spatial_dimensions_size(), nullptr);
for (int64 spatial_dim_index = 0;
spatial_dim_index < dnums.input_spatial_dimensions_size();
++spatial_dim_index) {
int64 input_spatial_dim = dnums.input_spatial_dimensions(spatial_dim_index);
int64 kernel_spatial_dim =
dnums.kernel_spatial_dimensions(spatial_dim_index);
HloInstruction* activations_dynamic_size =
dynamic_dimension_inference->GetDynamicSize(
custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
if (activations_dynamic_size != nullptr) {
activations = PadWithScalar(activations, input_spatial_dim,
activations_dynamic_size, zero);
}
HloInstruction* gradients_dynamic_size =
dynamic_dimension_inference->GetDynamicSize(
custom_call_conv->mutable_operand(1), {}, kernel_spatial_dim);
if (gradients_dynamic_size != nullptr) {
gradients = PadWithScalar(gradients, kernel_spatial_dim,
gradients_dynamic_size, zero);
}
if (activations_dynamic_size == nullptr ||
gradients_dynamic_size == nullptr) {
TF_RET_CHECK(activations_dynamic_size == nullptr &&
gradients_dynamic_size == nullptr);
continue;
}
int64 output_spatial_dim =
dnums.output_spatial_dimensions(spatial_dim_index);
const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
activations_dynamic_size, /*window_size=*/
custom_call_conv->shape().dimensions(output_spatial_dim),
/*window_dilation=*/window_dim.stride(),
/*window_stride=*/window_dim.window_dilation(),
custom_call_conv->padding_type());
padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
}
if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
activations = RewriteInputWithDynamicPadding(
custom_call_conv, activations, absl::MakeSpan(padding_before), &window);
}
HloInstruction* static_conv = comp->AddInstruction(
HloInstruction::CreateConvolve(
custom_call_conv->shape(), activations, gradients,
custom_call_conv->feature_group_count(),
custom_call_conv->batch_group_count(), window,
custom_call_conv->convolution_dimension_numbers(),
custom_call_conv->precision_config()),
"ConvBackwardGrad");
TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
custom_call_conv, static_conv, {}));
return true;
}
StatusOr<bool> RewriteDynamicConcat(
HloInstruction* concat,
DynamicDimensionInference* dynamic_dimension_inference) {
@ -1323,6 +1582,23 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
inst, static_reshape, {}));
continue;
}
if (inst->IsCustomCall("DynamicConvolutionInputGrad")) {
TF_ASSIGN_OR_RETURN(changed, RewriteDynamicConvolutionInputGrad(
inst, &dynamic_dimension_inference));
continue;
}
if (inst->IsCustomCall("DynamicConvolutionForward")) {
TF_ASSIGN_OR_RETURN(changed, RewriteDynamicConvolutionForward(
inst, &dynamic_dimension_inference));
continue;
}
if (inst->IsCustomCall("DynamicConvolutionKernelGrad")) {
TF_ASSIGN_OR_RETURN(changed, RewriteDynamicConvolutionKernelGrad(
inst, &dynamic_dimension_inference));
continue;
}
for (int64 operand_num = 0; operand_num < inst->operand_count();
++operand_num) {
HloInstruction* original_operand = inst->mutable_operand(operand_num);
@ -1398,9 +1674,9 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
changed = changed || replaced_set_bound;
}
}
HloDCE dce;
TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
VLOG(2) << "Post DynamicPadder HLO:";
XLA_VLOG_LINES(2, module->ToString());
dynamic_padding_gauge->GetCell()->Set(changed);

View File

@ -0,0 +1,150 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
#include <string>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
namespace xla {
namespace {
// HloOp wraps an instruction pointer to do arithmetic based on operator
// overloading.
//
// TODO(yunxing): This is only used internally to this file to provide a
// convenient way to do operator overloading. Find an idiom and merge this
// with hlo_creation_utils.
class HloOp {
public:
HloOp() = default;
explicit HloOp(HloInstruction* inst) : inst_(inst) {}
void SetName(const std::string& name) {
inst_->SetAndSanitizeName(name);
if (inst_->GetModule() != nullptr) {
inst_->UniquifyName(&inst_->GetModule()->instruction_name_uniquer());
}
}
HloInstruction* get() { return inst_; }
private:
HloInstruction* inst_ = nullptr;
};
HloOp BinaryOp(HloOp x, HloOp y, HloOpcode opcode,
const std::string& name = "") {
CHECK_EQ(x.get()->parent(), y.get()->parent());
Shape binary_op_shape =
ShapeInference::InferBinaryOpShape(opcode, x.get(), y.get()).ValueOrDie();
return HloOp(x.get()->parent()->AddInstruction(
HloInstruction::CreateBinary(binary_op_shape, opcode, x.get(), y.get()),
name));
}
HloOp operator+(HloOp x, HloOp y) { return BinaryOp(x, y, HloOpcode::kAdd); }
HloOp operator-(HloOp x, HloOp y) {
return BinaryOp(x, y, HloOpcode::kSubtract);
}
HloOp operator*(HloOp x, HloOp y) {
return BinaryOp(x, y, HloOpcode::kMultiply);
}
HloOp operator/(HloOp x, HloOp y) { return BinaryOp(x, y, HloOpcode::kDivide); }
HloOp Maximum(HloOp x, HloOp y, const std::string& name = "") {
return BinaryOp(x, y, HloOpcode::kMaximum, name);
}
template <typename NativeT>
HloOp ConstantR0(HloComputation* comp, NativeT value,
const std::string& name = "") {
return HloOp(comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)),
name));
}
template <typename NativeT>
HloOp One(HloComputation* comp) {
return ConstantR0<NativeT>(comp, 1, "one");
}
template <typename NativeT>
HloOp Zero(HloComputation* comp) {
return ConstantR0<NativeT>(comp, 0, "zero");
}
HloOp EffectiveFilterSize(HloComputation* comp, int64 window_size,
int64 window_dilation) {
return ConstantR0<int32>(comp, (window_size - 1) * window_dilation + 1,
"effective_filter_size");
}
} // namespace
DynamicWindowDims GetWindowedOutputSize(HloInstruction* input_size,
int64 window_size,
int64 window_dilation,
int64 window_stride,
PaddingType padding_type) {
HloComputation* comp = input_size->parent();
DynamicWindowDims result;
HloOp stride = ConstantR0<int32>(comp, window_stride, "stride");
HloOp effective_filter_size =
EffectiveFilterSize(comp, window_size, window_dilation);
if (padding_type == PaddingType::PADDING_VALID) {
HloOp output =
(HloOp(input_size) + stride - effective_filter_size) / stride;
result.output_size = output.get();
result.padding_before = Zero<int32>(comp).get();
} else if (padding_type == PaddingType::PADDING_SAME) {
HloOp output = (HloOp(input_size) + stride - One<int32>(comp)) / stride;
HloOp padding_needed = Maximum(
Zero<int32>(comp), (output - One<int32>(comp)) * stride +
effective_filter_size - HloOp(input_size));
HloOp padding_before = padding_needed / ConstantR0<int32>(comp, 2);
result.padding_before = padding_before.get();
result.output_size = output.get();
}
return result;
}
DynamicWindowDims GetWindowedInputGradSize(HloInstruction* input_size,
int64 window_size,
int64 window_dilation,
int64 window_stride,
PaddingType padding_type) {
HloComputation* comp = input_size->parent();
DynamicWindowDims result;
HloOp effective_filter_size =
ConstantR0<int32>(comp, (window_size - 1) * window_dilation + 1);
HloOp stride = ConstantR0<int32>(comp, window_stride);
DynamicWindowDims forward_dims = GetWindowedOutputSize(
input_size, window_size, window_dilation, window_stride, padding_type);
HloOp output_size =
(HloOp(forward_dims.output_size) - One<int32>(comp)) * stride +
One<int32>(comp);
HloOp padding_before = effective_filter_size - One<int32>(comp) -
HloOp(forward_dims.padding_before);
result.output_size = output_size.get();
result.padding_before = padding_before.get();
return result;
}
} // namespace xla
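
In summary, the HLO these helpers emit computes the following (my reading of the code above; it mirrors the static GetWindowedOutputSizeVerboseV2 arithmetic, with integer floor division on non-negative S32 values):

$$
\begin{aligned}
\text{effective\_filter} &= (\text{window\_size}-1)\cdot\text{window\_dilation}+1\\
\textsc{valid}:\quad \text{output} &= \left\lfloor \tfrac{\text{input}+\text{stride}-\text{effective\_filter}}{\text{stride}} \right\rfloor,\qquad \text{padding\_before}=0\\
\textsc{same}:\quad \text{output} &= \left\lfloor \tfrac{\text{input}+\text{stride}-1}{\text{stride}} \right\rfloor\\
\text{padding\_needed} &= \max\bigl(0,\ (\text{output}-1)\cdot\text{stride}+\text{effective\_filter}-\text{input}\bigr)\\
\text{padding\_before} &= \lfloor \text{padding\_needed}/2 \rfloor\\
\text{input grad}:\quad \text{output} &= (\text{forward\_output}-1)\cdot\text{stride}+1\\
\text{padding\_before} &= \text{effective\_filter}-1-\text{forward\_padding\_before}
\end{aligned}
$$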

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_WINDOW_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_WINDOW_UTILS_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
struct DynamicWindowDims {
HloInstruction* padding_before;
HloInstruction* output_size;
};
// This mirrors the logic in GetWindowedOutputSizeVerboseV2 but with HLOs as
// inputs and outputs.
DynamicWindowDims GetWindowedOutputSize(HloInstruction* input_size,
int64 window_size,
int64 window_dilation,
int64 window_stride,
PaddingType padding_type);
DynamicWindowDims GetWindowedInputGradSize(HloInstruction* input_size,
int64 window_size,
int64 window_dilation,
int64 window_stride,
PaddingType padding_type);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_WINDOW_UTILS_H_

View File

@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
// Next ID: 75
// Next ID: 76
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@ -259,6 +259,9 @@ message HloInstructionProto {
// Specifies if this is a cross-program-prefetch, used by kCopyStart.
bool is_cross_program_prefetch = 73;
// If a convolution is dynamic, a dynamic padding type will be specified.
xla.PaddingType padding_type = 75;
}
// Serialization of HloComputation.

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -580,6 +581,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
custom_call_instr->set_custom_call_has_side_effect(
proto.custom_call_has_side_effect());
custom_call_instr->set_padding_type(proto.padding_type());
PrecisionConfig precision_config = proto.precision_config();
precision_config.mutable_operand_precision()->Resize(
proto.operand_ids_size(), PrecisionConfig::DEFAULT);
*custom_call_instr->mutable_precision_config() = precision_config;
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_to_operand_aliasing;
for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
@ -2951,6 +2958,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
bool HloInstruction::IsCustomCall(absl::string_view target) const {
return opcode() == HloOpcode::kCustomCall && custom_call_target() == target;
}
bool HloInstruction::IsInputFusion() const {
return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
}
@ -3826,6 +3837,10 @@ const PrecisionConfig& HloInstruction::precision_config() const {
if (auto* dot = DynCast<HloDotInstruction>(this)) {
return dot->precision_config();
}
if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
return custom_call->precision_config();
}
LOG(FATAL) << "Unimplemented method.";
}
@ -4183,6 +4198,10 @@ const PaddingConfig& HloInstruction::padding_config() const {
return Cast<HloPadInstruction>(this)->padding_config();
}
PaddingType HloInstruction::padding_type() const {
return Cast<HloCustomCallInstruction>(this)->padding_type();
}
PaddingConfig* HloInstruction::mutable_padding_config() {
return Cast<HloPadInstruction>(this)->mutable_padding_config();
}

View File

@ -1366,6 +1366,8 @@ class HloInstruction {
// instruction.
bool IsFusible() const;
bool IsCustomCall(absl::string_view target) const;
// Returns the sharding applied to this operator.
// REQUIRES: has_sharding() is true.
const HloSharding& sharding() const {
@ -1843,6 +1845,9 @@ class HloInstruction {
const PaddingConfig& padding_config() const;
PaddingConfig* mutable_padding_config();
// Delegates to HloConvolutionInstruction::padding_type.
PaddingType padding_type() const;
// Delegates to HloDynamicSliceInstruction::slice_sizes.
int64 slice_sizes(int64 dimension) const;

View File

@ -2199,7 +2199,6 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
if (!precision_config_string.empty()) {
extra.push_back(precision_config_string);
}
return extra;
}
@ -2346,6 +2345,7 @@ HloCustomCallInstruction::HloCustomCallInstruction(
feature_group_count_(1),
batch_group_count_(1),
layout_constrained_(false),
padding_type_(PaddingType::PADDING_INVALID),
custom_call_has_side_effect_(false) {
set_raw_backend_config_string(std::move(opaque));
for (auto operand : operands) {
@ -2362,6 +2362,7 @@ HloCustomCallInstruction::HloCustomCallInstruction(
feature_group_count_(1),
batch_group_count_(1),
layout_constrained_(false),
padding_type_(PaddingType::PADDING_INVALID),
custom_call_has_side_effect_(false) {
set_raw_backend_config_string(std::move(opaque));
for (auto operand : operands) {
@ -2379,6 +2380,7 @@ HloCustomCallInstruction::HloCustomCallInstruction(
feature_group_count_(1),
batch_group_count_(1),
layout_constrained_(true),
padding_type_(PaddingType::PADDING_INVALID),
operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
operand_shapes_with_layout.end()),
custom_call_has_side_effect_(false) {
@ -2400,6 +2402,8 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
proto.set_custom_call_target(custom_call_target_);
proto.set_feature_group_count(feature_group_count_);
proto.set_batch_group_count(batch_group_count_);
*proto.mutable_precision_config() = precision_config_;
proto.set_padding_type(padding_type_);
if (layout_constrained()) {
proto.set_constrain_layout(true);
for (const Shape& shape : operand_shapes_with_layout_) {
@ -2437,6 +2441,13 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
if (batch_group_count_ != 1) {
extra.push_back(StrCat("batch_group_count=", batch_group_count_));
}
string precision_config_string = PrecisionConfigToString(precision_config_);
if (!precision_config_string.empty()) {
extra.push_back(precision_config_string);
}
if (padding_type_ != PaddingType::PADDING_INVALID) {
extra.push_back(StrCat("padding_type=", PaddingType_Name(padding_type())));
}
// By contract, we print the custom call target even if
// options.print_subcomputation_mode() == kOff, because the call target is not
// an HloComputation.
@ -2492,6 +2503,11 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
if (batch_group_count_ != casted_other.batch_group_count_) {
return false;
}
if (padding_type_ != casted_other.padding_type()) {
return false;
}
if (layout_constrained() != casted_other.layout_constrained()) {
return false;
}
@ -2511,6 +2527,10 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
casted_other.output_to_operand_aliasing()) {
return false;
}
if (!protobuf_util::ProtobufEquals(precision_config(),
casted_other.precision_config())) {
return false;
}
// Note: backend_config comparison is done in Identical, which is the
// intended/exposed way to compare computations, and so not repeated here.
return custom_call_target_ == casted_other.custom_call_target_;
@ -2536,6 +2556,8 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
cloned->set_batch_group_count(batch_group_count_);
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
cloned->set_padding_type(padding_type_);
*cloned->mutable_precision_config() = precision_config();
return std::move(cloned);
}

View File

@ -1453,6 +1453,16 @@ class HloCustomCallInstruction : public HloInstruction {
bool custom_call_has_side_effect() const {
return custom_call_has_side_effect_;
}
// Returns padding type used for ops like convolution.
PaddingType padding_type() const { return padding_type_; }
void set_padding_type(PaddingType padding_type) {
padding_type_ = padding_type;
}
const PrecisionConfig& precision_config() const { return precision_config_; }
PrecisionConfig* mutable_precision_config() { return &precision_config_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@ -1502,6 +1512,11 @@ class HloCustomCallInstruction : public HloInstruction {
int64 batch_group_count_;
// Whether the result and operand layouts are constrained.
bool layout_constrained_;
// Information used to communicate to the implementation about the algorithm
// used to produce results for convolution instructions.
PrecisionConfig precision_config_;
// Describes the padding type for convolution instructions.
PaddingType padding_type_;
// For layout-constrained custom calls, this vector holds the shape with
// layout for each operand.
std::vector<Shape> operand_shapes_with_layout_;

View File

@ -193,6 +193,7 @@ class HloParserImpl : public HloParser {
kHloComputation,
kBracedHloComputationList,
kFftType,
kPaddingType,
kComparisonDirection,
kComparisonType,
kWindow,
@ -328,6 +329,7 @@ class HloParserImpl : public HloParser {
bool ParseTiles(std::vector<Tile>* tiles);
bool ParseOpcode(HloOpcode* result);
bool ParseFftType(FftType* result);
bool ParsePaddingType(PaddingType* result);
bool ParseComparisonDirection(ComparisonDirection* result);
bool ParseComparisonType(Comparison::Type* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
@ -1838,6 +1840,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
optional<HloComputation*> to_apply;
optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
output_to_operand_aliasing;
optional<PaddingType> padding_type;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
@ -1856,6 +1859,12 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
attrs["output_to_operand_aliasing"] = {/*required=*/false,
AttrTy::kInstructionAliasing,
&output_to_operand_aliasing};
attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
&padding_type};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@ -1921,6 +1930,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
if (batch_group_count.has_value()) {
custom_call_instr->set_batch_group_count(*batch_group_count);
}
if (padding_type.has_value()) {
custom_call_instr->set_padding_type(*padding_type);
}
if (custom_call_has_side_effect.has_value()) {
custom_call_instr->set_custom_call_has_side_effect(
*custom_call_has_side_effect);
@ -1929,6 +1941,15 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
custom_call_instr->set_output_to_operand_aliasing(
std::move(*output_to_operand_aliasing));
}
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
operands.size(), PrecisionConfig::DEFAULT);
}
*custom_call_instr->mutable_precision_config() = precision_config;
break;
}
case HloOpcode::kDot: {
@ -3105,6 +3126,14 @@ bool HloParserImpl::ParseAttributeHelper(
static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kPaddingType: {
PaddingType result;
if (!ParsePaddingType(&result)) {
return false;
}
static_cast<optional<PaddingType>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kComparisonDirection: {
ComparisonDirection result;
if (!ParseComparisonDirection(&result)) {
@ -4246,6 +4275,19 @@ bool HloParserImpl::ParseFftType(FftType* result) {
return true;
}
bool HloParserImpl::ParsePaddingType(PaddingType* result) {
VLOG(3) << "ParsePaddingType";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects padding type");
}
std::string val = lexer_.GetStrVal();
if (!PaddingType_Parse(val, result) || !PaddingType_IsValid(*result)) {
return TokenError(StrFormat("expects padding type but sees: %s", val));
}
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
VLOG(3) << "ParseComparisonDirection";
if (lexer_.GetKind() != TokKind::kIdent) {

View File

@ -451,6 +451,20 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
}
)"
},
// convolution dynamic
{
"ConvolutionDynamic",
R"(HloModule Convolve1D1Window_0_module
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %custom-call.52 = f32[1,2,1]{2,0,1} custom-call(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}, custom_call_target="DynamicConvolutionForward", metadata={op_type="Conv2D" op_name="conv1d"}
}
)"
},
// convolution rank 2

View File

@ -1811,19 +1811,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// Input feature dimension is a contracting dimension, which does not
// affect the output dimension size. So we need to do nothing.
} else {
return InvalidArgument(
"Dynamic Spatial Convolution is not supported: lhs shape is %s ",
lhs.ToString());
for (int64 j = 0; j < dnums.output_spatial_dimensions_size(); ++j) {
if (i == dnums.input_spatial_dimensions(j)) {
// i is a spatial dimension, find corresponding output spatial
// dimension.
is_dynamic[dnums.output_spatial_dimensions(j)] = true;
}
}
}
}
if (rhs.is_dynamic_dimension(i)) {
if (i == dnums.kernel_input_feature_dimension()) {
// Kernel feature dimension does not affect the output dimension size.
// So we need to do nothing.
} else {
} else if (i == dnums.kernel_output_feature_dimension()) {
return InvalidArgument(
"Dynamic Spatial Convolution is not supported: rhs shape is %s ",
"Dynamic output feature dim on convolution kernel is not "
"supported: rhs shape is %s ",
rhs.ToString());
} else {
for (int64 j = 0; j < dnums.kernel_spatial_dimensions_size(); ++j) {
if (i == dnums.kernel_spatial_dimensions(j)) {
// i is a spatial dimension, find corresponding output spatial
// dimension.
is_dynamic[dnums.output_spatial_dimensions(j)] = true;
}
}
}
}
}

View File

@ -536,6 +536,12 @@ message ConvolutionDimensionNumbers {
// Next = 13
}
enum PaddingType {
PADDING_INVALID = 0;
PADDING_VALID = 1; // Only the valid portion of the base is covered.
PADDING_SAME = 2; // Extra padding is added to produce the same output size as the input.
}
enum FftType {
FFT = 0; // Forward FFT; complex in, complex out.
IFFT = 1; // Inverse FFT; complex in, complex out.