[XLA:GPU] Migrate convolution thunk emission to MLIR

- Map MLIR LHLO Conv operations to ConvolutionThunk.
- Refactor GetGpuConvConfig() to take a conv descriptor, so that the same function
  can be used to emit a thunk from either the XLA HLO or the MLIR representation
  (see the sketch after this list).
- Change XlaConvLayoutsToStreamExecutorLayouts() to check shape<->layout
  compatibility instead of layout equality, since some XLA layouts with unit dimensions
  are not preserved along the XLA HLO Shape -> MLIR -> XLA HLO Shape path.
- Note that window reversal is not yet representable in HLO/LHLO, so fall back to
  XLA HLO-based thunk emission for that case.
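
For reference, the refactor in the second bullet splits config construction into two
steps: populate an IR-agnostic GpuConvDescriptor, then call the descriptor-based
GetGpuConvConfig(). A minimal sketch of the HLO side follows (illustrative only; the
wrapper name BuildConvConfigFromHloForIllustration is hypothetical — in the change
itself this body is the HLO-based GetGpuConvConfig() overload in gpu_conv_runner.cc):

    StatusOr<GpuConvConfig> BuildConvConfigFromHloForIllustration(
        const HloCustomCallInstruction* cudnn_call) {
      // Describe the convolution in IR-agnostic terms.
      GpuConvDescriptor desc;
      TF_ASSIGN_OR_RETURN(desc.kind, GetCudnnConvKind(cudnn_call));
      TF_ASSIGN_OR_RETURN(desc.backend_config,
                          cudnn_call->backend_config<CudnnConvBackendConfig>());
      desc.operand0_shape = cudnn_call->operand(0)->shape();
      desc.operand1_shape = cudnn_call->operand(1)->shape();
      desc.result_shape = cudnn_call->shape().tuple_shapes(0);
      desc.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
      desc.window = cudnn_call->window();
      desc.dnums = cudnn_call->convolution_dimension_numbers();
      desc.feature_group_count = cudnn_call->feature_group_count();
      // The string form of the instruction is used only for error messages.
      return GetGpuConvConfig(desc, cudnn_call->ToString());
    }

The MLIR path (EmitConvolutionThunkFromMlir in the diff below) fills the same
descriptor from lmhlo_gpu convolution ops and their backend_config attributes.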

PiperOrigin-RevId: 348051790
Change-Id: I46cf3c08bb8d2440481b570146cb4d7607d2843a
Author:       Rahul Joshi (2020-12-17 10:50:51 -08:00)
Committed by: TensorFlower Gardener
Parent:       869022b4d6
Commit:       8684c6b2e9

9 changed files with 340 additions and 68 deletions

tensorflow/compiler/mlir/xla/BUILD

@@ -250,6 +250,7 @@ cc_library(
],
hdrs = ["mlir_hlo_to_hlo.h"],
deps = [
":attribute_exporter",
":type_to_shape",
"//tensorflow/compiler/mlir:name_utils",
"//tensorflow/compiler/mlir/hlo",
@@ -338,6 +339,24 @@ cc_library(
],
)
cc_library(
name = "attribute_exporter",
srcs = ["attribute_exporter.cc"],
hdrs = ["attribute_exporter.h"],
deps = [
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor:dnn",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "translate_cl_options",
srcs = ["xla_mlir_translate_cl.cc"],

tensorflow/compiler/mlir/xla/attribute_exporter.cc

@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/dnn.h"
namespace xla {
ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
mlir::mhlo::ConvDimensionNumbers input) {
ConvolutionDimensionNumbers output;
output.set_input_batch_dimension(
input.input_batch_dimension().getValue().getSExtValue());
output.set_input_feature_dimension(
input.input_feature_dimension().getValue().getSExtValue());
for (auto v : input.input_spatial_dimensions().getValues<int64>()) {
output.add_input_spatial_dimensions(v);
}
output.set_kernel_input_feature_dimension(
input.kernel_input_feature_dimension().getValue().getSExtValue());
output.set_kernel_output_feature_dimension(
input.kernel_output_feature_dimension().getValue().getSExtValue());
for (auto v : input.kernel_spatial_dimensions().getValues<int64>()) {
output.add_kernel_spatial_dimensions(v);
}
output.set_output_batch_dimension(
input.output_batch_dimension().getValue().getSExtValue());
output.set_output_feature_dimension(
input.output_feature_dimension().getValue().getSExtValue());
for (auto v : input.output_spatial_dimensions().getValues<int64>()) {
output.add_output_spatial_dimensions(v);
}
return output;
}
StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
llvm::StringRef input) {
llvm::Optional<mlir::lmhlo_gpu::Activation> activation =
mlir::lmhlo_gpu::symbolizeActivation(input);
if (!activation) {
return InternalError("Unexpected activation");
}
switch (activation.getValue()) {
case mlir::lmhlo_gpu::Activation::None:
return stream_executor::dnn::kNone;
case mlir::lmhlo_gpu::Activation::Sigmoid:
return stream_executor::dnn::kSigmoid;
case mlir::lmhlo_gpu::Activation::Tanh:
return stream_executor::dnn::kTanh;
case mlir::lmhlo_gpu::Activation::Relu:
return stream_executor::dnn::kRelu;
case mlir::lmhlo_gpu::Activation::Relu6:
return stream_executor::dnn::kRelu6;
case mlir::lmhlo_gpu::Activation::ReluX:
return stream_executor::dnn::kReluX;
case mlir::lmhlo_gpu::Activation::BandPass:
return stream_executor::dnn::kBandPass;
default:
return InternalError("Unexpected activation");
}
}
} // namespace xla

tensorflow/compiler/mlir/xla/attribute_exporter.h

@@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/dnn.h"
namespace xla {
// Converts the conv dimensions attribute to XLA HLO.
ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
mlir::mhlo::ConvDimensionNumbers input);
StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
llvm::StringRef input);
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_

tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc

@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -297,36 +298,7 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
mlir::mhlo::ConvDimensionNumbers input) {
xla::ConvolutionDimensionNumbers output;
output.set_input_batch_dimension(
input.input_batch_dimension().getValue().getSExtValue());
output.set_input_feature_dimension(
input.input_feature_dimension().getValue().getSExtValue());
for (int64 v : input.input_spatial_dimensions().getValues<int64>()) {
output.add_input_spatial_dimensions(v);
}
output.set_kernel_input_feature_dimension(
input.kernel_input_feature_dimension().getValue().getSExtValue());
output.set_kernel_output_feature_dimension(
input.kernel_output_feature_dimension().getValue().getSExtValue());
for (int64 v : input.kernel_spatial_dimensions().getValues<int64>()) {
output.add_kernel_spatial_dimensions(v);
}
output.set_output_batch_dimension(
input.output_batch_dimension().getValue().getSExtValue());
output.set_output_feature_dimension(
input.output_feature_dimension().getValue().getSExtValue());
for (int64 v : input.output_spatial_dimensions().getValues<int64>()) {
output.add_output_spatial_dimensions(v);
}
return output;
return xla::ConvertConvDimensionNumbers(input);
}
xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {

tensorflow/compiler/xla/service/gpu/BUILD

@@ -270,6 +270,7 @@ cc_library(
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
"//tensorflow/compiler/mlir/xla:attribute_exporter",
"//tensorflow/compiler/mlir/xla:hlo_module_importer",
"//tensorflow/compiler/mlir/xla:hlo_utils",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",

tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
@@ -255,15 +256,17 @@ Status RunGpuConvImpl(const GpuConvParams& params,
} // anonymous namespace
StatusOr<GpuConvConfig> GetGpuConvConfig(
const HloCustomCallInstruction* cudnn_call) {
const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
GpuConvConfig config;
config.input_type = cudnn_call->operand(0)->shape().element_type();
config.output_type = cudnn_call->shape().tuple_shapes(0).element_type();
const Shape& operand0_shape = desc.operand0_shape;
const Shape& operand1_shape = desc.operand1_shape;
const Shape& result_shape = desc.result_shape;
const CudnnConvBackendConfig& backend_config = desc.backend_config;
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
cudnn_call->backend_config<CudnnConvBackendConfig>());
TF_ASSIGN_OR_RETURN(config.kind, GetCudnnConvKind(cudnn_call));
config.input_type = operand0_shape.element_type();
config.output_type = result_shape.element_type();
config.kind = desc.kind;
// The third field is scratch size stored from conv_algorithm_picker
// The operand is added to the shape field of the conv instruction
@@ -271,13 +274,9 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
config.algorithm = se::dnn::AlgorithmConfig(
se::dnn::AlgorithmDesc(backend_config.algorithm(),
backend_config.tensor_ops_enabled()),
cudnn_call->shape().tuple_shapes(1).dimensions(0));
desc.scratch_size);
config.conv_result_scale = backend_config.conv_result_scale();
Shape operand0_shape = cudnn_call->operand(0)->shape();
Shape operand1_shape = cudnn_call->operand(1)->shape();
Shape result_shape = cudnn_call->shape().tuple_shapes(0);
switch (config.kind) {
case CudnnConvKind::kForward:
case CudnnConvKind::kForwardActivation:
@@ -311,9 +310,8 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
fusion.side_input_scale = backend_config.side_input_scale();
}
const Window& window = cudnn_call->window();
const ConvolutionDimensionNumbers& dnums =
cudnn_call->convolution_dimension_numbers();
const Window& window = desc.window;
const ConvolutionDimensionNumbers& dnums = desc.dnums;
VLOG(3) << "Convolution Algorithm: "
<< config.algorithm.algorithm()->algo_id();
@@ -330,7 +328,7 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
const int num_dimensions = window.dimensions_size();
CHECK_LE(num_dimensions, 3) << cudnn_call->ToString();
CHECK_LE(num_dimensions, 3) << inst_as_string;
// cuDNN does not support 1D convolutions. We therefore express 1D
// convolutions as 2D convolutions where the first spatial dimension is 1.
@@ -344,18 +342,18 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
<< cudnn_call->ToString();
<< inst_as_string;
CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
<< cudnn_call->ToString();
<< inst_as_string;
CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
<< cudnn_call->ToString();
<< inst_as_string;
for (const WindowDimension& dim : window.dimensions()) {
CHECK_EQ(dims_reversed, dim.window_reversal()) << cudnn_call->ToString();
CHECK_EQ(dim.padding_low(), dim.padding_high()) << cudnn_call->ToString();
CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
CHECK_EQ(dim.base_dilation(), 1)
<< "cudnn does not support base dilation; it "
"must be made explicit with a kPad: "
<< cudnn_call->ToString();
<< inst_as_string;
}
// cuDNN's convolution APIs support the BDYX layout for activations/output and
@@ -364,43 +362,43 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
FilterLayout filter_dl;
DataLayout output_dl;
const Shape* input_shape = &config.input_shape;
const Shape* filter_shape = &config.filter_shape;
const Shape* output_shape = &config.output_shape;
const Shape& input_shape = config.input_shape;
const Shape& filter_shape = config.filter_shape;
const Shape& output_shape = config.output_shape;
TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
XlaConvLayoutsToStreamExecutorLayouts(
dnums, input_shape->layout(), filter_shape->layout(),
output_shape->layout()));
dnums, input_shape.layout(), filter_shape.layout(),
output_shape.layout()));
BatchDescriptor& input_descriptor = config.input_descriptor;
input_descriptor = BatchDescriptor(effective_num_dimensions);
input_descriptor.set_layout(input_dl)
.set_feature_map_count(
input_shape->dimensions(dnums.input_feature_dimension()))
.set_count(input_shape->dimensions(dnums.input_batch_dimension()));
input_shape.dimensions(dnums.input_feature_dimension()))
.set_count(input_shape.dimensions(dnums.input_batch_dimension()));
for (int dim = 0; dim < num_dimensions; ++dim) {
// Note that the dimensions are reversed. The same holds below.
input_descriptor.set_spatial_dim(
static_cast<DimIndex>(effective_num_dimensions - dim - 1),
input_shape->dimensions(dnums.input_spatial_dimensions(dim)));
input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
}
FilterDescriptor& filter_descriptor = config.filter_descriptor;
filter_descriptor = FilterDescriptor(effective_num_dimensions);
filter_descriptor.set_layout(filter_dl)
.set_input_feature_map_count(
filter_shape->dimensions(dnums.kernel_input_feature_dimension()))
filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
.set_output_feature_map_count(
filter_shape->dimensions(dnums.kernel_output_feature_dimension()));
filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
for (int dim = 0; dim < num_dimensions; ++dim) {
filter_descriptor.set_spatial_dim(
static_cast<DimIndex>(effective_num_dimensions - dim - 1),
filter_shape->dimensions(dnums.kernel_spatial_dimensions(dim)));
filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
}
config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
config.conv_desc.set_group_count(cudnn_call->feature_group_count());
config.conv_desc.set_group_count(desc.feature_group_count);
config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
for (int dim = 0; dim < num_dimensions; ++dim) {
config.conv_desc
@@ -419,12 +417,12 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
output_descriptor = BatchDescriptor(effective_num_dimensions);
output_descriptor.set_layout(output_dl)
.set_feature_map_count(
output_shape->dimensions(dnums.output_feature_dimension()))
.set_count(output_shape->dimensions(dnums.output_batch_dimension()));
output_shape.dimensions(dnums.output_feature_dimension()))
.set_count(output_shape.dimensions(dnums.output_batch_dimension()));
for (int dim = 0; dim < num_dimensions; ++dim) {
output_descriptor.set_spatial_dim(
static_cast<DimIndex>(effective_num_dimensions - dim - 1),
output_shape->dimensions(dnums.output_spatial_dimensions(dim)));
output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
}
// Add a singleton dimension in the 1D convolution case.
@@ -439,6 +437,23 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
return config;
}
StatusOr<GpuConvConfig> GetGpuConvConfig(
const HloCustomCallInstruction* cudnn_call) {
GpuConvDescriptor descriptor;
TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
TF_ASSIGN_OR_RETURN(descriptor.backend_config,
cudnn_call->backend_config<CudnnConvBackendConfig>());
descriptor.operand0_shape = cudnn_call->operand(0)->shape();
descriptor.operand1_shape = cudnn_call->operand(1)->shape();
descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
descriptor.window = cudnn_call->window();
descriptor.dnums = cudnn_call->convolution_dimension_numbers();
descriptor.feature_group_count = cudnn_call->feature_group_count();
return GetGpuConvConfig(descriptor, cudnn_call->ToString());
}
StatusOr<GpuConvParams> GetGpuConvParams(
const GpuConvConfig& config,
absl::Span<se::DeviceMemoryBase> operand_buffers,

tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h

@@ -119,9 +119,31 @@ Status RunGpuConv(const GpuConvConfig& conv_config,
se::ScratchAllocator* scratch_allocator, se::Stream* stream,
RunConvOptions = {});
// Struct to describe properties of a convolution without being tied to specific
// IR. Will be used to help build Convolution thunks from either XLA HLO or
// LHLO GPU dialect in MLIR.
struct GpuConvDescriptor {
CudnnConvKind kind;
CudnnConvBackendConfig backend_config;
Shape operand0_shape;
Shape operand1_shape;
Shape result_shape;
size_t scratch_size;
Window window;
ConvolutionDimensionNumbers dnums;
int64 feature_group_count;
};
// Returns the convolution configuration given a XLA HLO instruction.
StatusOr<GpuConvConfig> GetGpuConvConfig(
const HloCustomCallInstruction* cudnn_call);
// Returns the convolution configuration given a convolution descriptor `desc`
// and a string representation of the convolution instruction `inst_as_string`
// (for error reporting).
StatusOr<GpuConvConfig> GetGpuConvConfig(const GpuConvDescriptor& desc,
absl::string_view inst_as_string);
// Implementation details exposed for debugging and log analysis.
StatusOr<GpuConvParams> GetGpuConvParams(
const GpuConvConfig& conv_config,

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

@@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@@ -62,6 +63,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
@@ -1046,8 +1048,12 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
return EmitGemmThunkFromMlir(input);
}
if (IsCustomCallToDnnConvolution(*custom_call)) {
return ThunkEmitter(this).HandleCustomCall(custom_call);
if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
mlir::lmhlo_gpu::ConvForwardFusedOp,
mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
mlir::lmhlo_gpu::ConvBackwardFilterOp,
mlir::lmhlo_gpu::ConvBackwardInputOp>(input.op)) {
return EmitConvolutionThunkFromMlir(input);
}
#if GOOGLE_CUDA
@@ -1060,6 +1066,118 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
custom_call->custom_call_target());
}
Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
using mlir::dyn_cast;
using mlir::lmhlo_gpu::Activation;
using mlir::lmhlo_gpu::ConvBackwardFilterOp;
using mlir::lmhlo_gpu::ConvBackwardInputOp;
using mlir::lmhlo_gpu::ConvForwardFusedOp;
using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
using mlir::lmhlo_gpu::ConvForwardOp;
// Last 2 operands of the convolution operation are the result and scratch.
std::vector<BufferAllocation::Slice> operand_slices;
int64 num_operands = input.op->getNumOperands();
operand_slices.reserve(num_operands - 2);
for (mlir::Value operand : input.op->getOperands().drop_back(2)) {
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand));
operand_slices.push_back(slice);
}
mlir::Value conv_result = input.op->getOperand(num_operands - 2);
mlir::Value scratch_result = input.op->getOperand(num_operands - 1);
TF_ASSIGN_OR_RETURN(auto conv_result_slice,
GetAllocationSliceForMlir(conv_result));
TF_ASSIGN_OR_RETURN(auto scratch_slice,
GetAllocationSliceForMlir(scratch_result));
auto apply_layout = [](const Shape& shape, mlir::ArrayAttr layout_attrib) {
mlir::SmallVector<int64, 4> minor_to_major = llvm::to_vector<4>(
llvm::map_range(layout_attrib, [](mlir::Attribute a) -> int64 {
return static_cast<int64>(a.cast<mlir::IntegerAttr>().getInt());
}));
return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
shape.dimensions(), minor_to_major);
};
GpuConvDescriptor descriptor;
auto fill_conv_descriptor = [&](auto op) {
descriptor.operand0_shape =
apply_layout(TypeToShape(input.op->getOperand(0).getType()),
op.backend_config().operand_0_layout());
descriptor.operand1_shape =
apply_layout(TypeToShape(input.op->getOperand(1).getType()),
op.backend_config().operand_1_layout());
descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
op.backend_config().result_layout());
descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
mlir::DenseIntElementsAttr padding = op.padding().getValue();
mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue();
mlir::DenseElementsAttr window_reversal = op.window_reversal().getValue();
for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
WindowDimension* dim = descriptor.window.add_dimensions();
// Window size for a convolution is the same as the kernel size.
// Kernel size of the convolution is operand1_shape. We need to look at
// the convolution dimension numbers kernel spatial dimensions to get
// the window size.
int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
dim->set_stride(window_strides.getValue<int64>(index));
dim->set_padding_low(padding.getValue<int64>(index));
dim->set_padding_high(padding.getValue<int64>(index));
dim->set_base_dilation(lhs_dilation.getValue<int64>(index));
dim->set_window_dilation(rhs_dilation.getValue<int64>(index));
dim->set_window_reversal(window_reversal.getValue<bool>(index));
}
descriptor.feature_group_count = op.feature_group_count();
descriptor.backend_config.set_algorithm(
op.backend_config().algorithm().getInt());
descriptor.backend_config.set_tensor_ops_enabled(
op.backend_config().tensor_ops_enabled().getValue());
descriptor.backend_config.set_conv_result_scale(
op.result_scale().convertToDouble());
};
auto set_activation_mode = [&](auto op) -> Status {
TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
ConvertConvActivationMode(op.activation_mode()));
descriptor.backend_config.set_activation_mode(
static_cast<int64>(activation_mode));
return Status::OK();
};
if (auto op = dyn_cast<ConvForwardOp>(input.op)) {
descriptor.kind = CudnnConvKind::kForward;
fill_conv_descriptor(op);
} else if (auto op = dyn_cast<ConvBackwardInputOp>(input.op)) {
descriptor.kind = CudnnConvKind::kBackwardInput;
fill_conv_descriptor(op);
} else if (auto op = dyn_cast<ConvBackwardFilterOp>(input.op)) {
descriptor.kind = CudnnConvKind::kBackwardFilter;
fill_conv_descriptor(op);
} else if (auto op = dyn_cast<ConvForwardFusedOp>(input.op)) {
descriptor.kind = CudnnConvKind::kForwardActivation;
fill_conv_descriptor(op);
TF_RETURN_IF_ERROR(set_activation_mode(op));
} else if (auto op = dyn_cast<ConvForwardFusedSideInputOp>(input.op)) {
descriptor.kind = CudnnConvKind::kForwardActivation;
fill_conv_descriptor(op);
TF_RETURN_IF_ERROR(set_activation_mode(op));
descriptor.backend_config.set_side_input_scale(
op.side_input_scale().convertToDouble());
} else {
return InternalError("Unexpected operation");
}
TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
input.thunk_info, std::move(config), std::move(operand_slices),
conv_result_slice, scratch_slice));
return Status::OK();
}
Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
auto build_gemm_config = [](auto op) {
GpuGemmConfig config;

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h

@@ -168,6 +168,7 @@ class IrEmitterUnnested : public IrEmitter,
Status HandleConditional(HloInstruction* conditional) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);
Status EmitGemmThunkFromMlir(MlirEmitterInput input);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);