Support dynamic spatial dimensions in convolution.
- Lower dynamic convolutions into kCustomCalls. - Transform those kCustomCalls into static convolutions in dynamic padder. PiperOrigin-RevId: 339275495 Change-Id: I0e1a6c0ff7f539e63482f1de7d564dca23ab81bc
This commit is contained in:
parent
9c31fe8ef6
commit
292fa95c75
@ -269,9 +269,19 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
|
||||
dims.set_output_feature_dimension(feature_dim);
|
||||
dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
|
||||
dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
|
||||
|
||||
xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
|
||||
for (int i = 0; i < attrs.num_spatial_dims; ++i) {
|
||||
const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
|
||||
if (input_shape.is_dynamic_dimension(dim)) {
|
||||
TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
|
||||
<< "Dynamic convolution only supports valid and same padding";
|
||||
if (attrs.padding == VALID) {
|
||||
padding_type = xla::PaddingType::PADDING_VALID;
|
||||
}
|
||||
if (attrs.padding == SAME) {
|
||||
padding_type = xla::PaddingType::PADDING_SAME;
|
||||
}
|
||||
}
|
||||
dims.add_input_spatial_dimensions(dim);
|
||||
dims.add_kernel_spatial_dimensions(i);
|
||||
dims.add_output_spatial_dimensions(dim);
|
||||
@ -290,6 +300,15 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
|
||||
&padding[i].first, &padding[i].second));
|
||||
}
|
||||
|
||||
if (padding_type != xla::PaddingType::PADDING_INVALID) {
|
||||
return xla::DynamicConvForward(
|
||||
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dims,
|
||||
/*feature_group_count=*/attrs.depthwise ? in_depth
|
||||
: feature_group_count,
|
||||
/*batch_group_count=*/1, precision_config, padding_type);
|
||||
}
|
||||
|
||||
return xla::ConvGeneralDilated(
|
||||
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dims,
|
||||
@ -300,7 +319,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config) {
|
||||
const xla::PrecisionConfig* precision_config, xla::XlaOp* input_sizes) {
|
||||
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
|
||||
|
||||
int num_dims = attrs.num_spatial_dims + 2;
|
||||
@ -347,8 +366,19 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
|
||||
std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
|
||||
std::vector<int64> ones(attrs.num_spatial_dims, 1);
|
||||
xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
|
||||
for (int i = 0; i < attrs.num_spatial_dims; ++i) {
|
||||
int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
|
||||
if (out_backprop_shape.is_dynamic_dimension(dim)) {
|
||||
TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
|
||||
<< "Dynamic convolution only supports valid and same padding";
|
||||
if (attrs.padding == VALID) {
|
||||
padding_type = xla::PaddingType::PADDING_VALID;
|
||||
}
|
||||
if (attrs.padding == SAME) {
|
||||
padding_type = xla::PaddingType::PADDING_SAME;
|
||||
}
|
||||
}
|
||||
dnums.add_input_spatial_dimensions(dim);
|
||||
dnums.add_kernel_spatial_dimensions(i);
|
||||
dnums.add_output_spatial_dimensions(dim);
|
||||
@ -366,7 +396,15 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
}
|
||||
// Mirror the filter in the spatial dimensions.
|
||||
filter = xla::Rev(filter, kernel_spatial_dims);
|
||||
|
||||
if (padding_type != xla::PaddingType::PADDING_INVALID) {
|
||||
TF_RET_CHECK(input_sizes != nullptr);
|
||||
return xla::DynamicConvInputGrad(
|
||||
*input_sizes, out_backprop, filter, /*window_strides=*/ones, padding,
|
||||
lhs_dilation, rhs_dilation, dnums,
|
||||
/*feature_group_count=*/
|
||||
feature_group_count,
|
||||
/*batch_group_count=*/1, precision_config, padding_type);
|
||||
}
|
||||
// activation gradients
|
||||
// = gradients (with padding and dilation) <conv> mirrored_weights
|
||||
return xla::ConvGeneralDilated(out_backprop, filter, /*window_strides=*/ones,
|
||||
@ -444,9 +482,19 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
||||
for (int i = 0; i < attrs.num_spatial_dims; ++i) {
|
||||
dnums.add_output_spatial_dimensions(i);
|
||||
}
|
||||
|
||||
xla::PaddingType padding_type = xla::PaddingType::PADDING_INVALID;
|
||||
for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
|
||||
int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
|
||||
if (activations_shape.is_dynamic_dimension(dim)) {
|
||||
TF_RET_CHECK(attrs.padding == VALID || attrs.padding == SAME)
|
||||
<< "Dynamic convolution only supports valid and same padding";
|
||||
if (attrs.padding == VALID) {
|
||||
padding_type = xla::PaddingType::PADDING_VALID;
|
||||
}
|
||||
if (attrs.padding == SAME) {
|
||||
padding_type = xla::PaddingType::PADDING_SAME;
|
||||
}
|
||||
}
|
||||
dnums.add_input_spatial_dimensions(dim);
|
||||
dnums.add_kernel_spatial_dimensions(dim);
|
||||
rhs_dilation[i] = dims.spatial_dims[i].stride;
|
||||
@ -503,12 +551,20 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
||||
//
|
||||
// This is done by specifying the window dilation factors in the
|
||||
// convolution HLO below.
|
||||
|
||||
filter_backprop = xla::ConvGeneralDilated(
|
||||
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
|
||||
rhs_dilation, dnums,
|
||||
/*feature_group_count=*/1,
|
||||
/*batch_group_count=*/batch_group_count, precision_config);
|
||||
if (padding_type != xla::PaddingType::PADDING_INVALID) {
|
||||
filter_backprop = xla::DynamicConvKernelGrad(
|
||||
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
|
||||
rhs_dilation, dnums,
|
||||
/*feature_group_count=*/1,
|
||||
/*batch_group_count=*/batch_group_count, precision_config,
|
||||
padding_type);
|
||||
} else {
|
||||
filter_backprop = xla::ConvGeneralDilated(
|
||||
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
|
||||
rhs_dilation, dnums,
|
||||
/*feature_group_count=*/1,
|
||||
/*batch_group_count=*/batch_group_count, precision_config);
|
||||
}
|
||||
|
||||
if (attrs.depthwise) {
|
||||
filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions());
|
||||
|
@ -64,7 +64,8 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config = nullptr);
|
||||
const xla::PrecisionConfig* precision_config = nullptr,
|
||||
xla::XlaOp* input_sizes = nullptr);
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
||||
StringPiece type_string, xla::XlaOp activations,
|
||||
const xla::Shape& filter_shape, xla::XlaOp gradients,
|
||||
|
@ -112,10 +112,10 @@ class ConvBackpropInputOp : public XlaOpKernel {
|
||||
"The rank of the specified input shape must be "
|
||||
"num_spatial_dims + 2. Expected ",
|
||||
attrs_.num_spatial_dims + 2, " got ", input_shape.rank()));
|
||||
|
||||
xla::StatusOr<xla::XlaOp> in_backprop =
|
||||
MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
|
||||
ctx->Input(1), ctx->Input(2), attrs_);
|
||||
xla::XlaOp input_sizes = ctx->Input(0);
|
||||
xla::StatusOr<xla::XlaOp> in_backprop = MakeXlaBackpropInputConvOp(
|
||||
ctx->op_kernel().type_string(), input_shape, ctx->Input(1),
|
||||
ctx->Input(2), attrs_, nullptr, &input_sizes);
|
||||
OP_REQUIRES_OK(ctx, in_backprop.status());
|
||||
ctx->SetOutput(0, in_backprop.ValueOrDie());
|
||||
}
|
||||
|
@ -307,7 +307,8 @@ REGISTER_OP("XlaSetDynamicDimensionSize")
|
||||
.Input("size: int32")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
// Use unknown shape to prevent constant folding.
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(
|
||||
R"doc(Make a static dimension into a xla bounded dynamic dimension.
|
||||
The current static dimension size will become the bound and the second
|
||||
|
@ -1439,6 +1439,110 @@ XlaOp XlaBuilder::ConvGeneralDilated(
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
|
||||
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
|
||||
std::vector<int64> window_dimensions(
|
||||
dimension_numbers.kernel_spatial_dimensions_size());
|
||||
for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
|
||||
window_dimensions[i] =
|
||||
rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
|
||||
window_dimensions, window_strides,
|
||||
padding, lhs_dilation, rhs_dilation));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
*lhs_shape, *rhs_shape, feature_group_count,
|
||||
batch_group_count, window, dimension_numbers));
|
||||
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
*instr.mutable_window() = window;
|
||||
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
|
||||
instr.set_feature_group_count(feature_group_count);
|
||||
instr.set_batch_group_count(batch_group_count);
|
||||
instr.set_padding_type(padding_type);
|
||||
|
||||
if (precision_config != nullptr) {
|
||||
*instr.mutable_precision_config() = *precision_config;
|
||||
}
|
||||
return std::move(instr);
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::DynamicConvInputGrad(
|
||||
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstructionProto instr,
|
||||
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers,
|
||||
feature_group_count, batch_group_count,
|
||||
precision_config, padding_type));
|
||||
|
||||
instr.set_custom_call_target("DynamicConvolutionInputGrad");
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
|
||||
{input_sizes, lhs, rhs});
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::DynamicConvKernelGrad(
|
||||
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstructionProto instr,
|
||||
DynamicConvInstruction(activations, gradients, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation, dimension_numbers,
|
||||
feature_group_count, batch_group_count,
|
||||
precision_config, padding_type));
|
||||
|
||||
instr.set_custom_call_target("DynamicConvolutionKernelGrad");
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall,
|
||||
{activations, gradients});
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::DynamicConvForward(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstructionProto instr,
|
||||
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers,
|
||||
feature_group_count, batch_group_count,
|
||||
precision_config, padding_type));
|
||||
instr.set_custom_call_target("DynamicConvolutionForward");
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
|
||||
const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
|
||||
absl::Span<const int64> window_strides,
|
||||
@ -3901,6 +4005,49 @@ XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
|
||||
precision_config);
|
||||
}
|
||||
|
||||
XlaOp DynamicConvInputGrad(XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type) {
|
||||
return lhs.builder()->DynamicConvInputGrad(
|
||||
input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
|
||||
precision_config, padding_type);
|
||||
}
|
||||
|
||||
XlaOp DynamicConvKernelGrad(
|
||||
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
return activations.builder()->DynamicConvKernelGrad(
|
||||
activations, gradients, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
|
||||
precision_config, padding_type);
|
||||
}
|
||||
|
||||
XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type) {
|
||||
return lhs.builder()->DynamicConvForward(
|
||||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers, feature_group_count, batch_group_count,
|
||||
precision_config, padding_type);
|
||||
}
|
||||
|
||||
XlaOp Fft(const XlaOp operand, FftType fft_type,
|
||||
absl::Span<const int64> fft_length) {
|
||||
return operand.builder()->Fft(operand, fft_type, fft_length);
|
||||
|
@ -560,6 +560,45 @@ class XlaBuilder {
|
||||
int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
|
||||
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type);
|
||||
|
||||
XlaOp DynamicConvInputGrad(
|
||||
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
|
||||
XlaOp DynamicConvKernelGrad(
|
||||
XlaOp activations, XlaOp gradients,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
|
||||
StatusOr<HloInstructionProto> DynamicConvInstruction(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
|
||||
virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
|
||||
const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
|
||||
absl::Span<const int64> window_strides,
|
||||
@ -1057,13 +1096,49 @@ class XlaBuilder {
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config);
|
||||
const PrecisionConfig* precision_confige);
|
||||
friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config);
|
||||
friend XlaOp DynamicConvForward(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
friend XlaOp DynamicConvKernelGrad(
|
||||
XlaOp activations, XlaOp gradients,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
friend XlaOp DynamicConvInputGrad(
|
||||
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
|
||||
friend XlaOp ConvKernelGrad(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config);
|
||||
|
||||
friend XlaOp ConvGeneralDilated(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
@ -1761,6 +1836,34 @@ XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
|
||||
int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
|
||||
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type);
|
||||
|
||||
XlaOp DynamicConvInputGrad(XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type);
|
||||
|
||||
XlaOp DynamicConvKernelGrad(
|
||||
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
|
||||
// Enqueues an FFT instruction onto the computation, of the given type and
|
||||
// with the given FFT length.
|
||||
XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span<const int64> fft_length);
|
||||
|
@ -2721,17 +2721,35 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dynamic_window_utils",
|
||||
srcs = ["dynamic_window_utils.cc"],
|
||||
hdrs = ["dynamic_window_utils.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core/platform:macros",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dynamic_dimension_inference",
|
||||
srcs = ["dynamic_dimension_inference.cc"],
|
||||
hdrs = ["dynamic_dimension_inference.h"],
|
||||
deps = [
|
||||
":dynamic_window_utils",
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":tuple_util",
|
||||
":while_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -2751,10 +2769,13 @@ cc_library(
|
||||
hdrs = ["dynamic_padder.h"],
|
||||
deps = [
|
||||
":dynamic_dimension_inference",
|
||||
":dynamic_window_utils",
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":hlo_creation_utils",
|
||||
":hlo_dce",
|
||||
":hlo_pass",
|
||||
":hlo_verifier",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
@ -2762,6 +2783,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -29,10 +30,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/tuple_util.h"
|
||||
#include "tensorflow/compiler/xla/service/while_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
@ -157,6 +158,18 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
|
||||
using DynamicDimensionFn = std::function<Status(
|
||||
ShapeIndex index, int64 dimension, HloInstruction* dynamic_size)>;
|
||||
|
||||
Status HandleDynamicConvolutionForward(HloInstruction* hlo,
|
||||
int64 operand_index, int64 dimension,
|
||||
HloInstruction* dynamic_size);
|
||||
|
||||
Status HandleDynamicConvolutionKernelGrad(HloInstruction* hlo,
|
||||
int64 operand_index,
|
||||
int64 dimension);
|
||||
|
||||
Status HandleDynamicConvolutionInputGrad(HloInstruction* hlo,
|
||||
int64 operand_index,
|
||||
int64 dimension);
|
||||
|
||||
Status ForEachOperandDynamicDimension(HloInstruction* inst,
|
||||
const OperandDynamicDimensionFn&);
|
||||
Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
|
||||
@ -256,6 +269,20 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
|
||||
return Status::OK();
|
||||
}
|
||||
if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
|
||||
return HandleDynamicConvolutionInputGrad(hlo, operand_index,
|
||||
dimension);
|
||||
}
|
||||
|
||||
if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
|
||||
return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
|
||||
dimension);
|
||||
}
|
||||
|
||||
if (hlo->custom_call_target() == "DynamicConvolutionForward") {
|
||||
return HandleDynamicConvolutionForward(hlo, operand_index, dimension,
|
||||
dynamic_size);
|
||||
}
|
||||
return Unimplemented(
|
||||
"CustomCall \"%s\" is not supported to have a dynamic dimension",
|
||||
hlo->custom_call_target());
|
||||
@ -591,6 +618,70 @@ Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
|
||||
HloInstruction* hlo, int64 operand_index, int64 dimension,
|
||||
HloInstruction* dynamic_size) {
|
||||
TF_RET_CHECK(operand_index == 0);
|
||||
const ConvolutionDimensionNumbers& dimension_numbers =
|
||||
hlo->convolution_dimension_numbers();
|
||||
|
||||
if (dimension == dimension_numbers.input_batch_dimension()) {
|
||||
// Batch dimension is propagated without any changes.
|
||||
parent_->SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
|
||||
dynamic_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
for (int64 spatial_dim_index = 0;
|
||||
spatial_dim_index < dimension_numbers.input_spatial_dimensions_size();
|
||||
++spatial_dim_index) {
|
||||
int64 input_spatial_dim =
|
||||
dimension_numbers.input_spatial_dimensions(spatial_dim_index);
|
||||
int64 output_spatial_dim =
|
||||
dimension_numbers.output_spatial_dimensions(spatial_dim_index);
|
||||
if (dimension == input_spatial_dim) {
|
||||
// This is a dynamic spatial dimension. Calculate the output size.
|
||||
WindowDimension window_dim = hlo->window().dimensions(spatial_dim_index);
|
||||
DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
|
||||
dynamic_size, window_dim.size(), window_dim.window_dilation(),
|
||||
window_dim.stride(), hlo->padding_type());
|
||||
TF_RET_CHECK(window_dim.base_dilation() == 1);
|
||||
parent_->SetDynamicSize(hlo, {}, output_spatial_dim,
|
||||
dynamic_window_dims.output_size);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
return Unimplemented(
|
||||
"XLA doesn't support dynamic input feature dimension on convolution: %s",
|
||||
hlo->ToString());
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionInputGrad(
|
||||
HloInstruction* hlo, int64 operand_index, int64 dimension) {
|
||||
// The output size of convolution input grad is corresponding input size.
|
||||
HloInstruction* input_sizes = hlo->mutable_operand(0);
|
||||
HloComputation* comp = hlo->parent();
|
||||
TF_RET_CHECK(input_sizes->shape().rank() == 1) << hlo->ToString();
|
||||
TF_RET_CHECK(input_sizes->shape().element_type() == S32) << hlo->ToString();
|
||||
TF_RET_CHECK(input_sizes->shape().dimensions(0) ==
|
||||
hlo->shape().dimensions_size())
|
||||
<< hlo->ToString();
|
||||
// Slice to get corresponding input size.
|
||||
HloInstruction* slice = comp->AddInstruction(
|
||||
HloInstruction::CreateSlice(ShapeUtil::MakeShape(S32, {1}), input_sizes,
|
||||
{dimension}, {dimension + 1}, {1}));
|
||||
HloInstruction* reshape = comp->AddInstruction(
|
||||
HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, reshape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionKernelGrad(
|
||||
HloInstruction* hlo, int64 operand_index, int64 dimension) {
|
||||
// Dynamic convolution kernel grad produces static shape outputs.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
|
||||
HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
@ -1540,4 +1631,15 @@ HloInstruction* DynamicDimensionInference::GetDynamicSize(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<HloInstruction*> DynamicDimensionInference::GetDynamicSizes(
|
||||
HloInstruction* inst, const ShapeIndex& index) const {
|
||||
CHECK(ShapeUtil::IndexIsValid(inst->shape(), index));
|
||||
const int64 rank = ShapeUtil::GetSubshape(inst->shape(), index).rank();
|
||||
std::vector<HloInstruction*> result(rank, nullptr);
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
result[i] = GetDynamicSize(inst, {}, i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -51,8 +51,13 @@ class DynamicDimensionInference {
|
||||
HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
|
||||
int64 dim) const;
|
||||
|
||||
// Returns if current instruction contains any dynamic dimension. Recursively
|
||||
// go into tuples.
|
||||
// Returns dynamic sizes of all dimensions of `inst`'s leaf node at `index`.
|
||||
// Static sizes are represented by nullptr.
|
||||
std::vector<HloInstruction*> GetDynamicSizes(HloInstruction* inst,
|
||||
const ShapeIndex& index) const;
|
||||
|
||||
// Returns if current instruction contains any dynamic dimension.
|
||||
// Recursively go into tuples.
|
||||
bool HasDynamicDimension(HloInstruction* inst) const;
|
||||
|
||||
// Forward dynamic dimension size at `dim` from `inst` to `new_inst`.
|
||||
|
@ -27,8 +27,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
|
||||
#include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
@ -37,6 +39,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
@ -122,9 +125,9 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
|
||||
case HloOpcode::kSlice:
|
||||
case HloOpcode::kDomain:
|
||||
return nullptr;
|
||||
// Assume that custom calls created by the client are valid with padded
|
||||
// dynamic dimensions.
|
||||
case HloOpcode::kCustomCall:
|
||||
// Assume that custom calls created by the client are valid with padded
|
||||
// dynamic dimensions.
|
||||
return nullptr;
|
||||
default:
|
||||
return UnimplementedStrCat("Unimplemented padding for instruction: ",
|
||||
@ -721,6 +724,262 @@ Status RewriteDynamicReshapeSingleGroup(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
HloInstruction* RewriteInputWithDynamicPadding(
|
||||
HloInstruction* conv, HloInstruction* input,
|
||||
absl::Span<HloInstruction*> padding_before, Window* input_window) {
|
||||
HloComputation* comp = conv->parent();
|
||||
auto dnums = conv->convolution_dimension_numbers();
|
||||
HloInstruction* zero_s32 = comp->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
|
||||
Shape padded_grad_shape = input->shape();
|
||||
PaddingConfig padding_configs;
|
||||
for (int64 i = 0; i < input->shape().rank(); ++i) {
|
||||
PaddingConfig::PaddingConfigDimension padding_dim;
|
||||
*padding_configs.add_dimensions() = padding_dim;
|
||||
}
|
||||
std::vector<HloInstruction*> start_indices(input->shape().rank(), zero_s32);
|
||||
for (int64 spatial_dim_index = 0;
|
||||
spatial_dim_index < dnums.input_spatial_dimensions_size();
|
||||
++spatial_dim_index) {
|
||||
int64 input_spatial_dim = dnums.input_spatial_dimensions(spatial_dim_index);
|
||||
if (padding_before[spatial_dim_index] == nullptr) {
|
||||
continue;
|
||||
}
|
||||
WindowDimension* window_dim =
|
||||
input_window->mutable_dimensions(spatial_dim_index);
|
||||
auto* padding_dim = padding_configs.mutable_dimensions(input_spatial_dim);
|
||||
const int64 dilated_window_size = window_util::DilatedBound(
|
||||
window_dim->size(), window_dim->window_dilation());
|
||||
// Chosoe dilated window size as low padding and static padding_high +
|
||||
// padding_low as high padding to make sure the following dynamic slice is
|
||||
// valid.
|
||||
//
|
||||
// See go/xla-dynamic-spatial-dim for more details.
|
||||
padding_dim->set_edge_padding_low(dilated_window_size);
|
||||
padding_dim->set_edge_padding_high(window_dim->padding_high() +
|
||||
window_dim->padding_low());
|
||||
padding_dim->set_interior_padding(window_dim->base_dilation() - 1);
|
||||
HloInstruction* slicing_start =
|
||||
comp->AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract,
|
||||
comp->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR0<int32>(padding_dim->edge_padding_low()))),
|
||||
padding_before[spatial_dim_index]));
|
||||
start_indices[input_spatial_dim] = slicing_start;
|
||||
|
||||
padded_grad_shape.mutable_dimensions()[input_spatial_dim] =
|
||||
window_dim->padding_low() +
|
||||
window_util::DilatedBound(
|
||||
padded_grad_shape.dimensions(input_spatial_dim),
|
||||
window_dim->base_dilation()) +
|
||||
window_dim->padding_high();
|
||||
window_dim->clear_padding_high();
|
||||
window_dim->clear_padding_low();
|
||||
window_dim->set_base_dilation(1);
|
||||
input->mutable_shape()->set_dynamic_dimension(input_spatial_dim, false);
|
||||
}
|
||||
// Reconstruct dynamic padding using pad and dynamic slice.
|
||||
HloInstruction* pad =
|
||||
MakePadHlo(input,
|
||||
comp->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(conv->shape().element_type()))),
|
||||
padding_configs)
|
||||
.ValueOrDie();
|
||||
input = comp->AddInstruction(HloInstruction::CreateDynamicSlice(
|
||||
padded_grad_shape, pad, start_indices, padded_grad_shape.dimensions()));
|
||||
return input;
|
||||
}
|
||||
|
||||
StatusOr<bool> RewriteDynamicConvolutionInputGrad(
|
||||
HloInstruction* custom_call_conv,
|
||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
HloInstruction* grad = custom_call_conv->mutable_operand(1);
|
||||
HloInstruction* kernel = custom_call_conv->mutable_operand(2);
|
||||
TF_RET_CHECK(kernel->shape().is_static());
|
||||
auto dnums = custom_call_conv->convolution_dimension_numbers();
|
||||
HloComputation* comp = custom_call_conv->parent();
|
||||
Window window = custom_call_conv->window();
|
||||
HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(custom_call_conv->shape().element_type())));
|
||||
std::vector<HloInstruction*> padding_before(
|
||||
dnums.input_spatial_dimensions_size(), nullptr);
|
||||
for (int64 spatial_dim_index = 0;
|
||||
spatial_dim_index < dnums.input_spatial_dimensions_size();
|
||||
++spatial_dim_index) {
|
||||
int64 input_spatial_dim = dnums.input_spatial_dimensions(spatial_dim_index);
|
||||
HloInstruction* operand_dynamic_size =
|
||||
dynamic_dimension_inference->GetDynamicSize(
|
||||
custom_call_conv->mutable_operand(1), {}, input_spatial_dim);
|
||||
if (operand_dynamic_size == nullptr) {
|
||||
continue;
|
||||
}
|
||||
grad = PadWithScalar(grad, input_spatial_dim, operand_dynamic_size, zero);
|
||||
HloInstruction* slice = comp->AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(S32, {1}), custom_call_conv->mutable_operand(0),
|
||||
{input_spatial_dim}, {input_spatial_dim + 1}, {1}));
|
||||
HloInstruction* dynamic_input_size = comp->AddInstruction(
|
||||
HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
|
||||
const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
|
||||
// Window stride of forward prop is same as base dilation of backward prop.
|
||||
DynamicWindowDims dynamic_window_dims = GetWindowedInputGradSize(
|
||||
dynamic_input_size, /*window_size=*/window_dim.size(),
|
||||
/*window_dilation=*/window_dim.window_dilation(),
|
||||
/*window_stride=*/window_dim.base_dilation(),
|
||||
custom_call_conv->padding_type());
|
||||
padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
|
||||
}
|
||||
|
||||
if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
|
||||
grad = RewriteInputWithDynamicPadding(
|
||||
custom_call_conv, grad, absl::MakeSpan(padding_before), &window);
|
||||
}
|
||||
|
||||
PrecisionConfig precision_config;
|
||||
if (custom_call_conv->precision_config().operand_precision_size() == 3) {
|
||||
// We are not interested in the precision config of the first operand, which
|
||||
// is the input_sizes.
|
||||
*precision_config.mutable_operand_precision() = {
|
||||
custom_call_conv->precision_config().operand_precision().begin() + 1,
|
||||
custom_call_conv->precision_config().operand_precision().end()};
|
||||
}
|
||||
HloInstruction* static_conv = comp->AddInstruction(
|
||||
HloInstruction::CreateConvolve(
|
||||
custom_call_conv->shape(), grad, kernel,
|
||||
custom_call_conv->feature_group_count(),
|
||||
custom_call_conv->batch_group_count(), window,
|
||||
custom_call_conv->convolution_dimension_numbers(),
|
||||
custom_call_conv->precision_config()),
|
||||
"ConvBackwardInput");
|
||||
TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
|
||||
TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
|
||||
custom_call_conv, static_conv, {}));
|
||||
return true;
|
||||
}
|
||||
|
||||
StatusOr<bool> RewriteDynamicConvolutionForward(
|
||||
HloInstruction* custom_call_conv,
|
||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
HloInstruction* input = custom_call_conv->mutable_operand(0);
|
||||
HloInstruction* kernel = custom_call_conv->mutable_operand(1);
|
||||
TF_RET_CHECK(kernel->shape().is_static());
|
||||
TF_RET_CHECK(input->shape().is_dynamic());
|
||||
HloComputation* comp = custom_call_conv->parent();
|
||||
Window window = custom_call_conv->window();
|
||||
auto dnums = custom_call_conv->convolution_dimension_numbers();
|
||||
HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(custom_call_conv->shape().element_type())));
|
||||
std::vector<HloInstruction*> padding_before(
|
||||
dnums.input_spatial_dimensions_size(), nullptr);
|
||||
for (int64 spatial_dim_index = 0;
|
||||
spatial_dim_index < dnums.input_spatial_dimensions_size();
|
||||
++spatial_dim_index) {
|
||||
int64 input_spatial_dim = dnums.input_spatial_dimensions(spatial_dim_index);
|
||||
HloInstruction* operand_dynamic_size =
|
||||
dynamic_dimension_inference->GetDynamicSize(
|
||||
custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
|
||||
if (operand_dynamic_size == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
input = PadWithScalar(input, input_spatial_dim, operand_dynamic_size, zero);
|
||||
const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
|
||||
DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
|
||||
operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
|
||||
window_dim.stride(), custom_call_conv->padding_type());
|
||||
padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
|
||||
}
|
||||
|
||||
if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
|
||||
input = RewriteInputWithDynamicPadding(
|
||||
custom_call_conv, input, absl::MakeSpan(padding_before), &window);
|
||||
}
|
||||
|
||||
HloInstruction* static_conv = comp->AddInstruction(
|
||||
HloInstruction::CreateConvolve(
|
||||
custom_call_conv->shape(), input, kernel,
|
||||
custom_call_conv->feature_group_count(),
|
||||
custom_call_conv->batch_group_count(), window,
|
||||
custom_call_conv->convolution_dimension_numbers(),
|
||||
custom_call_conv->precision_config()),
|
||||
"ConvForward");
|
||||
TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
|
||||
TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
|
||||
custom_call_conv, static_conv, {}));
|
||||
return true;
|
||||
}
|
||||
|
||||
StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
|
||||
HloInstruction* custom_call_conv,
|
||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
HloInstruction* activations = custom_call_conv->mutable_operand(0);
|
||||
HloInstruction* gradients = custom_call_conv->mutable_operand(1);
|
||||
TF_RET_CHECK(activations->shape().is_dynamic());
|
||||
TF_RET_CHECK(gradients->shape().is_dynamic());
|
||||
HloComputation* comp = custom_call_conv->parent();
|
||||
Window window = custom_call_conv->window();
|
||||
auto dnums = custom_call_conv->convolution_dimension_numbers();
|
||||
HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(custom_call_conv->shape().element_type())));
|
||||
std::vector<HloInstruction*> padding_before(
|
||||
dnums.input_spatial_dimensions_size(), nullptr);
|
||||
for (int64 spatial_dim_index = 0;
|
||||
spatial_dim_index < dnums.input_spatial_dimensions_size();
|
||||
++spatial_dim_index) {
|
||||
int64 input_spatial_dim = dnums.input_spatial_dimensions(spatial_dim_index);
|
||||
int64 kernel_spatial_dim =
|
||||
dnums.kernel_spatial_dimensions(spatial_dim_index);
|
||||
HloInstruction* activations_dynamic_size =
|
||||
dynamic_dimension_inference->GetDynamicSize(
|
||||
custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
|
||||
if (activations_dynamic_size != nullptr) {
|
||||
activations = PadWithScalar(activations, input_spatial_dim,
|
||||
activations_dynamic_size, zero);
|
||||
}
|
||||
|
||||
HloInstruction* gradients_dynamic_size =
|
||||
dynamic_dimension_inference->GetDynamicSize(
|
||||
custom_call_conv->mutable_operand(1), {}, kernel_spatial_dim);
|
||||
if (gradients_dynamic_size != nullptr) {
|
||||
gradients = PadWithScalar(gradients, kernel_spatial_dim,
|
||||
gradients_dynamic_size, zero);
|
||||
}
|
||||
if (activations_dynamic_size == nullptr ||
|
||||
gradients_dynamic_size == nullptr) {
|
||||
TF_RET_CHECK(activations_dynamic_size == nullptr &&
|
||||
gradients_dynamic_size == nullptr);
|
||||
continue;
|
||||
}
|
||||
int64 output_spatial_dim =
|
||||
dnums.output_spatial_dimensions(spatial_dim_index);
|
||||
const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
|
||||
DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
|
||||
activations_dynamic_size, /*window_size=*/
|
||||
custom_call_conv->shape().dimensions(output_spatial_dim),
|
||||
/*window_dilation=*/window_dim.stride(),
|
||||
/*window_stride=*/window_dim.window_dilation(),
|
||||
custom_call_conv->padding_type());
|
||||
padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
|
||||
}
|
||||
|
||||
if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
|
||||
activations = RewriteInputWithDynamicPadding(
|
||||
custom_call_conv, activations, absl::MakeSpan(padding_before), &window);
|
||||
}
|
||||
|
||||
HloInstruction* static_conv = comp->AddInstruction(
|
||||
HloInstruction::CreateConvolve(
|
||||
custom_call_conv->shape(), activations, gradients,
|
||||
custom_call_conv->feature_group_count(),
|
||||
custom_call_conv->batch_group_count(), window,
|
||||
custom_call_conv->convolution_dimension_numbers(),
|
||||
custom_call_conv->precision_config()),
|
||||
"ConvBackwardGrad");
|
||||
TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
|
||||
TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
|
||||
custom_call_conv, static_conv, {}));
|
||||
return true;
|
||||
}
|
||||
|
||||
StatusOr<bool> RewriteDynamicConcat(
|
||||
HloInstruction* concat,
|
||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
@ -1323,6 +1582,23 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
||||
inst, static_reshape, {}));
|
||||
continue;
|
||||
}
|
||||
if (inst->IsCustomCall("DynamicConvolutionInputGrad")) {
|
||||
TF_ASSIGN_OR_RETURN(changed, RewriteDynamicConvolutionInputGrad(
|
||||
inst, &dynamic_dimension_inference));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->IsCustomCall("DynamicConvolutionForward")) {
|
||||
TF_ASSIGN_OR_RETURN(changed, RewriteDynamicConvolutionForward(
|
||||
inst, &dynamic_dimension_inference));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->IsCustomCall("DynamicConvolutionKernelGrad")) {
|
||||
TF_ASSIGN_OR_RETURN(changed, RewriteDynamicConvolutionKernelGrad(
|
||||
inst, &dynamic_dimension_inference));
|
||||
continue;
|
||||
}
|
||||
for (int64 operand_num = 0; operand_num < inst->operand_count();
|
||||
++operand_num) {
|
||||
HloInstruction* original_operand = inst->mutable_operand(operand_num);
|
||||
@ -1398,9 +1674,9 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
||||
changed = changed || replaced_set_bound;
|
||||
}
|
||||
}
|
||||
|
||||
HloDCE dce;
|
||||
TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
|
||||
|
||||
VLOG(2) << "Post DynamicPadder HLO:";
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
dynamic_padding_gauge->GetCell()->Set(changed);
|
||||
|
150
tensorflow/compiler/xla/service/dynamic_window_utils.cc
Normal file
150
tensorflow/compiler/xla/service/dynamic_window_utils.cc
Normal file
@ -0,0 +1,150 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
// HloOp wraps an instuction pointer to do arithmetic based on operator
|
||||
// overloading.
|
||||
//
|
||||
// TODO(yunxing): This is only used internally to this file to provide a
|
||||
// convenient way to do operator overloadding. Find out an idiom and merge this
|
||||
// with hlo_creation_utils.
|
||||
class HloOp {
|
||||
public:
|
||||
HloOp() = default;
|
||||
explicit HloOp(HloInstruction* inst) : inst_(inst) {}
|
||||
void SetName(const std::string& name) {
|
||||
inst_->SetAndSanitizeName(name);
|
||||
if (inst_->GetModule() != nullptr) {
|
||||
inst_->UniquifyName(&inst_->GetModule()->instruction_name_uniquer());
|
||||
}
|
||||
}
|
||||
HloInstruction* get() { return inst_; }
|
||||
|
||||
private:
|
||||
HloInstruction* inst_ = nullptr;
|
||||
};
|
||||
HloOp BinaryOp(HloOp x, HloOp y, HloOpcode opcode,
|
||||
const std::string& name = "") {
|
||||
CHECK_EQ(x.get()->parent(), y.get()->parent());
|
||||
Shape binary_op_shape =
|
||||
ShapeInference::InferBinaryOpShape(opcode, x.get(), y.get()).ValueOrDie();
|
||||
return HloOp(x.get()->parent()->AddInstruction(
|
||||
HloInstruction::CreateBinary(binary_op_shape, opcode, x.get(), y.get()),
|
||||
name));
|
||||
}
|
||||
HloOp operator+(HloOp x, HloOp y) { return BinaryOp(x, y, HloOpcode::kAdd); }
|
||||
|
||||
HloOp operator-(HloOp x, HloOp y) {
|
||||
return BinaryOp(x, y, HloOpcode::kSubtract);
|
||||
}
|
||||
|
||||
HloOp operator*(HloOp x, HloOp y) {
|
||||
return BinaryOp(x, y, HloOpcode::kMultiply);
|
||||
}
|
||||
|
||||
HloOp operator/(HloOp x, HloOp y) { return BinaryOp(x, y, HloOpcode::kDivide); }
|
||||
|
||||
HloOp Maximum(HloOp x, HloOp y, const std::string& name = "") {
|
||||
return BinaryOp(x, y, HloOpcode::kMaximum, name);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
HloOp ConstantR0(HloComputation* comp, NativeT value,
|
||||
const std::string& name = "") {
|
||||
return HloOp(comp->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)),
|
||||
name));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
HloOp One(HloComputation* comp) {
|
||||
return ConstantR0<NativeT>(comp, 1, "one");
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
HloOp Zero(HloComputation* comp) {
|
||||
return ConstantR0<NativeT>(comp, 0, "zero");
|
||||
}
|
||||
|
||||
HloOp EffectiveFilterSize(HloComputation* comp, int64 window_size,
|
||||
int64 window_dilation) {
|
||||
return ConstantR0<int32>(comp, (window_size - 1) * window_dilation + 1,
|
||||
"effective_filter_size");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
DynamicWindowDims GetWindowedOutputSize(HloInstruction* input_size,
|
||||
int64 window_size,
|
||||
int64 window_dilation,
|
||||
int64 window_stride,
|
||||
PaddingType padding_type) {
|
||||
HloComputation* comp = input_size->parent();
|
||||
DynamicWindowDims result;
|
||||
|
||||
HloOp stride = ConstantR0<int32>(comp, window_stride, "stride");
|
||||
HloOp effective_filter_size =
|
||||
EffectiveFilterSize(comp, window_size, window_dilation);
|
||||
if (padding_type == PaddingType::PADDING_VALID) {
|
||||
HloOp output =
|
||||
(HloOp(input_size) + stride - effective_filter_size) / stride;
|
||||
result.output_size = output.get();
|
||||
result.padding_before = Zero<int32>(comp).get();
|
||||
} else if (padding_type == PaddingType::PADDING_SAME) {
|
||||
HloOp output = (HloOp(input_size) + stride - One<int32>(comp)) / stride;
|
||||
HloOp padding_needed = Maximum(
|
||||
Zero<int32>(comp), (output - One<int32>(comp)) * stride +
|
||||
effective_filter_size - HloOp(input_size));
|
||||
HloOp padding_before = padding_needed / ConstantR0<int32>(comp, 2);
|
||||
result.padding_before = padding_before.get();
|
||||
result.output_size = output.get();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
DynamicWindowDims GetWindowedInputGradSize(HloInstruction* input_size,
|
||||
int64 window_size,
|
||||
int64 window_dilation,
|
||||
int64 window_stride,
|
||||
PaddingType padding_type) {
|
||||
HloComputation* comp = input_size->parent();
|
||||
DynamicWindowDims result;
|
||||
HloOp effective_filter_size =
|
||||
ConstantR0<int32>(comp, (window_size - 1) * window_dilation + 1);
|
||||
HloOp stride = ConstantR0<int32>(comp, window_stride);
|
||||
DynamicWindowDims forward_dims = GetWindowedOutputSize(
|
||||
input_size, window_size, window_dilation, window_stride, padding_type);
|
||||
HloOp output_size =
|
||||
(HloOp(forward_dims.output_size) - One<int32>(comp)) * stride +
|
||||
One<int32>(comp);
|
||||
HloOp padding_before = effective_filter_size - One<int32>(comp) -
|
||||
HloOp(forward_dims.padding_before);
|
||||
result.output_size = output_size.get();
|
||||
result.padding_before = padding_before.get();
|
||||
return result;
|
||||
}
|
||||
} // namespace xla
|
51
tensorflow/compiler/xla/service/dynamic_window_utils.h
Normal file
51
tensorflow/compiler/xla/service/dynamic_window_utils.h
Normal file
@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_WINDOW_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_WINDOW_UTILS_H_

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {
struct DynamicWindowDims {
  HloInstruction* padding_before;
  HloInstruction* output_size;
};

// This mirrors the logic in GetWindowedOutputSizeVerboseV2 but with HLOs as
// inputs and outputs.
DynamicWindowDims GetWindowedOutputSize(HloInstruction* input_size,
                                        int64 window_size,
                                        int64 window_dilation,
                                        int64 window_stride,
                                        PaddingType padding_type);

DynamicWindowDims GetWindowedInputGradSize(HloInstruction* input_size,
                                           int64 window_size,
                                           int64 window_dilation,
                                           int64 window_stride,
                                           PaddingType padding_type);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_WINDOW_UTILS_H_
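As a minimal usage sketch (assuming an XLA source tree to build against; the wrapper function name here is hypothetical, only the declarations above are from the change), a caller hands in the HLO that produces the dynamic spatial size and reads back the two result HLOs:

// Hypothetical caller, not part of this commit: derive output size and low
// padding for a 3-wide, stride-2, SAME-padded window whose input size is only
// known at runtime.
#include "tensorflow/compiler/xla/service/dynamic_window_utils.h"

namespace xla {

DynamicWindowDims ExampleWindowDims(HloInstruction* dynamic_input_size) {
  return GetWindowedOutputSize(dynamic_input_size, /*window_size=*/3,
                               /*window_dilation=*/1, /*window_stride=*/2,
                               PaddingType::PADDING_SAME);
}

}  // namespace xla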
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;

// Serialization of HloInstruction.
// Next ID: 75
// Next ID: 76
message HloInstructionProto {
  reserved 10;
  reserved "parameter_name";
@ -259,6 +259,9 @@ message HloInstructionProto {

  // Specifies if this is a cross-program-prefetch, used by kCopyStart.
  bool is_cross_program_prefetch = 73;

  // If a convolution is dynamic, a dynamic padding type will be specified.
  xla.PaddingType padding_type = 75;
}

// Serialization of HloComputation.
@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -580,6 +581,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
          std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
      custom_call_instr->set_custom_call_has_side_effect(
          proto.custom_call_has_side_effect());
      custom_call_instr->set_padding_type(proto.padding_type());

      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      *custom_call_instr->mutable_precision_config() = precision_config;
      std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_to_operand_aliasing;
      for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
@ -2951,6 +2958,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {

bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }

bool HloInstruction::IsCustomCall(absl::string_view target) const {
  return opcode() == HloOpcode::kCustomCall && custom_call_target() == target;
}

bool HloInstruction::IsInputFusion() const {
  return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
}
@ -3826,6 +3837,10 @@ const PrecisionConfig& HloInstruction::precision_config() const {
  if (auto* dot = DynCast<HloDotInstruction>(this)) {
    return dot->precision_config();
  }

  if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
    return custom_call->precision_config();
  }
  LOG(FATAL) << "Unimplemented method.";
}

@ -4183,6 +4198,10 @@ const PaddingConfig& HloInstruction::padding_config() const {
  return Cast<HloPadInstruction>(this)->padding_config();
}

PaddingType HloInstruction::padding_type() const {
  return Cast<HloCustomCallInstruction>(this)->padding_type();
}

PaddingConfig* HloInstruction::mutable_padding_config() {
  return Cast<HloPadInstruction>(this)->mutable_padding_config();
}
@ -1366,6 +1366,8 @@ class HloInstruction {
  // instruction.
  bool IsFusible() const;

  bool IsCustomCall(absl::string_view target) const;

  // Returns the sharding applied to this operator.
  // REQUIRES: has_sharding() is true.
  const HloSharding& sharding() const {
@ -1843,6 +1845,9 @@ class HloInstruction {
  const PaddingConfig& padding_config() const;
  PaddingConfig* mutable_padding_config();

  // Delegates to HloCustomCallInstruction::padding_type.
  PaddingType padding_type() const;

  // Delegates to HloDynamicSliceInstruction::slice_sizes.
  int64 slice_sizes(int64 dimension) const;
@ -2199,7 +2199,6 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
  if (!precision_config_string.empty()) {
    extra.push_back(precision_config_string);
  }

  return extra;
}

@ -2346,6 +2345,7 @@ HloCustomCallInstruction::HloCustomCallInstruction(
      feature_group_count_(1),
      batch_group_count_(1),
      layout_constrained_(false),
      padding_type_(PaddingType::PADDING_INVALID),
      custom_call_has_side_effect_(false) {
  set_raw_backend_config_string(std::move(opaque));
  for (auto operand : operands) {
@ -2362,6 +2362,7 @@ HloCustomCallInstruction::HloCustomCallInstruction(
      feature_group_count_(1),
      batch_group_count_(1),
      layout_constrained_(false),
      padding_type_(PaddingType::PADDING_INVALID),
      custom_call_has_side_effect_(false) {
  set_raw_backend_config_string(std::move(opaque));
  for (auto operand : operands) {
@ -2379,6 +2380,7 @@ HloCustomCallInstruction::HloCustomCallInstruction(
      feature_group_count_(1),
      batch_group_count_(1),
      layout_constrained_(true),
      padding_type_(PaddingType::PADDING_INVALID),
      operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
                                  operand_shapes_with_layout.end()),
      custom_call_has_side_effect_(false) {
@ -2400,6 +2402,8 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
  proto.set_custom_call_target(custom_call_target_);
  proto.set_feature_group_count(feature_group_count_);
  proto.set_batch_group_count(batch_group_count_);
  *proto.mutable_precision_config() = precision_config_;
  proto.set_padding_type(padding_type_);
  if (layout_constrained()) {
    proto.set_constrain_layout(true);
    for (const Shape& shape : operand_shapes_with_layout_) {
@ -2437,6 +2441,13 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
  if (batch_group_count_ != 1) {
    extra.push_back(StrCat("batch_group_count=", batch_group_count_));
  }
  string precision_config_string = PrecisionConfigToString(precision_config_);
  if (!precision_config_string.empty()) {
    extra.push_back(precision_config_string);
  }
  if (padding_type_ != PaddingType::PADDING_INVALID) {
    extra.push_back(StrCat("padding_type=", PaddingType_Name(padding_type())));
  }
  // By contract, we print the custom call target even if
  // options.print_subcomputation_mode() == kOff, because the call target is not
  // an HloComputation.
@ -2492,6 +2503,11 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
  if (batch_group_count_ != casted_other.batch_group_count_) {
    return false;
  }

  if (padding_type_ != casted_other.padding_type()) {
    return false;
  }

  if (layout_constrained() != casted_other.layout_constrained()) {
    return false;
  }
@ -2511,6 +2527,10 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
      casted_other.output_to_operand_aliasing()) {
    return false;
  }
  if (!protobuf_util::ProtobufEquals(precision_config(),
                                     casted_other.precision_config())) {
    return false;
  }
  // Note: backend_config comparison is done in Identical, which is the
  // intended/exposed way to compare computations, and so not repeated here.
  return custom_call_target_ == casted_other.custom_call_target_;
@ -2536,6 +2556,8 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
  cloned->set_batch_group_count(batch_group_count_);
  cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
  cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
  cloned->set_padding_type(padding_type_);
  *cloned->mutable_precision_config() = precision_config();
  return std::move(cloned);
}
@ -1453,6 +1453,16 @@ class HloCustomCallInstruction : public HloInstruction {
  bool custom_call_has_side_effect() const {
    return custom_call_has_side_effect_;
  }
  // Returns padding type used for ops like convolution.
  PaddingType padding_type() const { return padding_type_; }

  void set_padding_type(PaddingType padding_type) {
    padding_type_ = padding_type;
  }

  const PrecisionConfig& precision_config() const { return precision_config_; }
  PrecisionConfig* mutable_precision_config() { return &precision_config_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

@ -1502,6 +1512,11 @@ class HloCustomCallInstruction : public HloInstruction {
  int64 batch_group_count_;
  // Whether the result and operand layouts are constrained.
  bool layout_constrained_;
  // Information used to communicate to the implementation about the algorithm
  // used to produce results for convolution instructions.
  PrecisionConfig precision_config_;
  // Describes the padding type for convolution instructions.
  PaddingType padding_type_;
  // For layout-constrained custom calls, this vector holds the shape with
  // layout for each operand.
  std::vector<Shape> operand_shapes_with_layout_;
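A pass such as the dynamic padder mentioned in the change summary could key off these accessors when rewriting the custom call back into a static convolution. The helper below is a hedged sketch (the function name is hypothetical); it only uses the IsCustomCall and padding_type accessors introduced in this diff, and the "DynamicConvolutionForward" target that appears in the parser test later in the change.

// Hypothetical helper, illustrative only: true when an instruction is the
// forward dynamic-convolution custom call and was lowered with SAME padding.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {

bool IsSamePaddedDynamicConvForward(const HloInstruction* instr) {
  return instr->IsCustomCall("DynamicConvolutionForward") &&
         instr->padding_type() == PaddingType::PADDING_SAME;
}

}  // namespace xla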
@ -193,6 +193,7 @@ class HloParserImpl : public HloParser {
    kHloComputation,
    kBracedHloComputationList,
    kFftType,
    kPaddingType,
    kComparisonDirection,
    kComparisonType,
    kWindow,
@ -328,6 +329,7 @@ class HloParserImpl : public HloParser {
  bool ParseTiles(std::vector<Tile>* tiles);
  bool ParseOpcode(HloOpcode* result);
  bool ParseFftType(FftType* result);
  bool ParsePaddingType(PaddingType* result);
  bool ParseComparisonDirection(ComparisonDirection* result);
  bool ParseComparisonType(Comparison::Type* result);
  bool ParseFusionKind(HloInstruction::FusionKind* result);
@ -1838,6 +1840,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
      optional<HloComputation*> to_apply;
      optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
          output_to_operand_aliasing;
      optional<PaddingType> padding_type;
      attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
                                     &custom_call_target};
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
@ -1856,6 +1859,12 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
      attrs["output_to_operand_aliasing"] = {/*required=*/false,
                                             AttrTy::kInstructionAliasing,
                                             &output_to_operand_aliasing};

      attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
                               &padding_type};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
@ -1921,6 +1930,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
      if (batch_group_count.has_value()) {
        custom_call_instr->set_batch_group_count(*batch_group_count);
      }
      if (padding_type.has_value()) {
        custom_call_instr->set_padding_type(*padding_type);
      }
      if (custom_call_has_side_effect.has_value()) {
        custom_call_instr->set_custom_call_has_side_effect(
            *custom_call_has_side_effect);
@ -1929,6 +1941,15 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
        custom_call_instr->set_output_to_operand_aliasing(
            std::move(*output_to_operand_aliasing));
      }
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      *custom_call_instr->mutable_precision_config() = precision_config;
      break;
    }
    case HloOpcode::kDot: {
@ -3105,6 +3126,14 @@ bool HloParserImpl::ParseAttributeHelper(
      static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
      return true;
    }
    case AttrTy::kPaddingType: {
      PaddingType result;
      if (!ParsePaddingType(&result)) {
        return false;
      }
      static_cast<optional<PaddingType>*>(attr_out_ptr)->emplace(result);
      return true;
    }
    case AttrTy::kComparisonDirection: {
      ComparisonDirection result;
      if (!ParseComparisonDirection(&result)) {
@ -4246,6 +4275,19 @@ bool HloParserImpl::ParseFftType(FftType* result) {
  return true;
}

bool HloParserImpl::ParsePaddingType(PaddingType* result) {
  VLOG(3) << "ParsePaddingType";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects padding type");
  }
  std::string val = lexer_.GetStrVal();
  if (!PaddingType_Parse(val, result) || !PaddingType_IsValid(*result)) {
    return TokenError(StrFormat("expects padding type but sees: %s", val));
  }
  lexer_.Lex();
  return true;
}

bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
  VLOG(3) << "ParseComparisonDirection";
  if (lexer_.GetKind() != TokKind::kIdent) {
@ -451,6 +451,20 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
}

)"
},
// convolution dynamic
{
"ConvolutionDynamic",
R"(HloModule Convolve1D1Window_0_module

ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
  %input = f32[1,2,1]{2,1,0} parameter(0)
  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
  %filter = f32[1,1,1]{2,1,0} parameter(1)
  ROOT %custom-call.52 = f32[1,2,1]{2,0,1} custom-call(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}, custom_call_target="DynamicConvolutionForward", metadata={op_type="Conv2D" op_name="conv1d"}
}

)"
},
// convolution rank 2
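For reference, a further case could exercise the new padding_type attribute end to end through the parser. The entry below is a hedged illustration, not part of this commit; the module name and shapes are made up, and only the attribute key and PADDING_SAME value come from the parser and proto changes in this diff.

// Hypothetical additional case (illustrative only, not in this change): the
// same dynamic convolution custom call carrying the new padding_type
// attribute that ParsePaddingType now accepts.
{
"ConvolutionDynamicSamePadding",
R"(HloModule ConvolveDynamicPadding_module

ENTRY %ConvolveDynamicPadding (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
  %input = f32[1,2,1]{2,1,0} parameter(0)
  %filter = f32[1,1,1]{2,1,0} parameter(1)
  ROOT %custom-call = f32[1,2,1]{2,1,0} custom-call(f32[1,2,1]{2,1,0} %input, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, custom_call_target="DynamicConvolutionForward", padding_type=PADDING_SAME
}

)"
},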
@ -1811,19 +1811,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
        // Input feature dimension is a contracting dimension, which does not
        // affect the output dimension size. So we need to do nothing.
      } else {
        return InvalidArgument(
            "Dynamic Spatial Convolution is not supported: lhs shape is %s ",
            lhs.ToString());
        for (int64 j = 0; j < dnums.output_spatial_dimensions_size(); ++j) {
          if (i == dnums.input_spatial_dimensions(j)) {
            // i is a spatial dimension, find corresponding output spatial
            // dimension.
            is_dynamic[dnums.output_spatial_dimensions(j)] = true;
          }
        }
      }
    }
    if (rhs.is_dynamic_dimension(i)) {
      if (i == dnums.kernel_input_feature_dimension()) {
        // Kernel feature dimension does not affect the output dimension size.
        // So we need to do nothing.
      } else {
      } else if (i == dnums.kernel_output_feature_dimension()) {
        return InvalidArgument(
            "Dynamic Spatial Convolution is not supported: rhs shape is %s ",
            "Dynamic output feature dim on convolution kernel is not "
            "supported: rhs shape is %s ",
            rhs.ToString());
      } else {
        for (int64 j = 0; j < dnums.kernel_spatial_dimensions_size(); ++j) {
          if (i == dnums.kernel_spatial_dimensions(j)) {
            // i is a spatial dimension, find corresponding output spatial
            // dimension.
            is_dynamic[dnums.output_spatial_dimensions(j)] = true;
          }
        }
      }
    }
  }
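The bookkeeping above simply forwards a dynamic input spatial dimension to the matching output spatial dimension via the convolution dimension numbers. Below is a standalone, plain-integer illustration of that mapping under an assumed NHWC layout; the arrays and values are invented for the example and are not part of this change.

// Illustrative only: plain-int version of the spatial-dimension mapping used
// in the shape inference change above.
#include <cstdio>
#include <vector>

int main() {
  // NHWC: input spatial dims {1, 2} map one-to-one onto output spatial dims.
  const std::vector<int> input_spatial_dims = {1, 2};
  const std::vector<int> output_spatial_dims = {1, 2};
  const std::vector<bool> input_is_dynamic = {false, true, false, false};  // H dynamic.
  std::vector<bool> output_is_dynamic(4, false);
  for (int i = 0; i < 4; ++i) {
    if (!input_is_dynamic[i]) continue;
    for (size_t j = 0; j < input_spatial_dims.size(); ++j) {
      if (i == input_spatial_dims[j]) {
        output_is_dynamic[output_spatial_dims[j]] = true;
      }
    }
  }
  std::printf("output H dynamic: %d\n", static_cast<int>(output_is_dynamic[1]));  // 1
  return 0;
}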
@ -536,6 +536,12 @@ message ConvolutionDimensionNumbers {
  // Next = 13
}

enum PaddingType {
  PADDING_INVALID = 0;
  PADDING_VALID = 1;  // Only the valid portion of the base is covered.
  PADDING_SAME = 2;   // Extra is added to produce same output size as the input.
}

enum FftType {
  FFT = 0;   // Forward FFT; complex in, complex out.
  IFFT = 1;  // Inverse FFT; complex in, complex out.