[XLA:GPU] Migrate convolution thunk emission to MLIR
- Map MLIR LHLO Conv operations to ConvolutionThunk. - Refactor GetGpuConvConfig() to take in a conv descriptor, so that the same function can be used to emit a thunk from XLA HLO and MLIR representations. - Change XlaConvLayoutsToStreamExecutorLayouts() to check shape<->layout compatibility instead of layout equality, since some XLA layouts with unit dimensions are not preserved in XLA HLO Shape -> MLIR -> XLA HLO Shape path. - Note that window reversal is not yet representable in HLO/LHLO, so fall back to XLA HLO based thunk emission for that case. PiperOrigin-RevId: 348051790 Change-Id: I46cf3c08bb8d2440481b570146cb4d7607d2843a
This commit is contained in:
parent
869022b4d6
commit
8684c6b2e9
tensorflow/compiler
@ -250,6 +250,7 @@ cc_library(
|
||||
],
|
||||
hdrs = ["mlir_hlo_to_hlo.h"],
|
||||
deps = [
|
||||
":attribute_exporter",
|
||||
":type_to_shape",
|
||||
"//tensorflow/compiler/mlir:name_utils",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
@ -338,6 +339,24 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "attribute_exporter",
|
||||
srcs = ["attribute_exporter.cc"],
|
||||
hdrs = ["attribute_exporter.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//tensorflow/stream_executor:dnn",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "translate_cl_options",
|
||||
srcs = ["xla_mlir_translate_cl.cc"],
|
||||
|
87
tensorflow/compiler/mlir/xla/attribute_exporter.cc
Normal file
87
tensorflow/compiler/mlir/xla/attribute_exporter.cc
Normal file
@ -0,0 +1,87 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
|
||||
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/stream_executor/dnn.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
|
||||
mlir::mhlo::ConvDimensionNumbers input) {
|
||||
ConvolutionDimensionNumbers output;
|
||||
|
||||
output.set_input_batch_dimension(
|
||||
input.input_batch_dimension().getValue().getSExtValue());
|
||||
output.set_input_feature_dimension(
|
||||
input.input_feature_dimension().getValue().getSExtValue());
|
||||
|
||||
for (auto v : input.input_spatial_dimensions().getValues<int64>()) {
|
||||
output.add_input_spatial_dimensions(v);
|
||||
}
|
||||
|
||||
output.set_kernel_input_feature_dimension(
|
||||
input.kernel_input_feature_dimension().getValue().getSExtValue());
|
||||
output.set_kernel_output_feature_dimension(
|
||||
input.kernel_output_feature_dimension().getValue().getSExtValue());
|
||||
|
||||
for (auto v : input.kernel_spatial_dimensions().getValues<int64>()) {
|
||||
output.add_kernel_spatial_dimensions(v);
|
||||
}
|
||||
|
||||
output.set_output_batch_dimension(
|
||||
input.output_batch_dimension().getValue().getSExtValue());
|
||||
output.set_output_feature_dimension(
|
||||
input.output_feature_dimension().getValue().getSExtValue());
|
||||
|
||||
for (auto v : input.output_spatial_dimensions().getValues<int64>()) {
|
||||
output.add_output_spatial_dimensions(v);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
|
||||
llvm::StringRef input) {
|
||||
llvm::Optional<mlir::lmhlo_gpu::Activation> activation =
|
||||
mlir::lmhlo_gpu::symbolizeActivation(input);
|
||||
if (!activation) {
|
||||
return InternalError("Unexpected activation");
|
||||
}
|
||||
|
||||
switch (activation.getValue()) {
|
||||
case mlir::lmhlo_gpu::Activation::None:
|
||||
return stream_executor::dnn::kNone;
|
||||
case mlir::lmhlo_gpu::Activation::Sigmoid:
|
||||
return stream_executor::dnn::kSigmoid;
|
||||
case mlir::lmhlo_gpu::Activation::Tanh:
|
||||
return stream_executor::dnn::kTanh;
|
||||
case mlir::lmhlo_gpu::Activation::Relu:
|
||||
return stream_executor::dnn::kRelu;
|
||||
case mlir::lmhlo_gpu::Activation::Relu6:
|
||||
return stream_executor::dnn::kRelu6;
|
||||
case mlir::lmhlo_gpu::Activation::ReluX:
|
||||
return stream_executor::dnn::kReluX;
|
||||
case mlir::lmhlo_gpu::Activation::BandPass:
|
||||
return stream_executor::dnn::kBandPass;
|
||||
default:
|
||||
return InternalError("Unexpected activation");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
37
tensorflow/compiler/mlir/xla/attribute_exporter.h
Normal file
37
tensorflow/compiler/mlir/xla/attribute_exporter.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/dnn.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Converts the conv dimensions attribute to XLA HLO.
|
||||
ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
|
||||
mlir::mhlo::ConvDimensionNumbers input);
|
||||
|
||||
StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
|
||||
llvm::StringRef input);
|
||||
|
||||
} // namespace xla
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
|
@ -41,6 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/utils/name_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
@ -297,36 +298,7 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
|
||||
|
||||
static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
|
||||
mlir::mhlo::ConvDimensionNumbers input) {
|
||||
xla::ConvolutionDimensionNumbers output;
|
||||
|
||||
output.set_input_batch_dimension(
|
||||
input.input_batch_dimension().getValue().getSExtValue());
|
||||
output.set_input_feature_dimension(
|
||||
input.input_feature_dimension().getValue().getSExtValue());
|
||||
|
||||
for (int64 v : input.input_spatial_dimensions().getValues<int64>()) {
|
||||
output.add_input_spatial_dimensions(v);
|
||||
}
|
||||
|
||||
output.set_kernel_input_feature_dimension(
|
||||
input.kernel_input_feature_dimension().getValue().getSExtValue());
|
||||
output.set_kernel_output_feature_dimension(
|
||||
input.kernel_output_feature_dimension().getValue().getSExtValue());
|
||||
|
||||
for (int64 v : input.kernel_spatial_dimensions().getValues<int64>()) {
|
||||
output.add_kernel_spatial_dimensions(v);
|
||||
}
|
||||
|
||||
output.set_output_batch_dimension(
|
||||
input.output_batch_dimension().getValue().getSExtValue());
|
||||
output.set_output_feature_dimension(
|
||||
input.output_feature_dimension().getValue().getSExtValue());
|
||||
|
||||
for (int64 v : input.output_spatial_dimensions().getValues<int64>()) {
|
||||
output.add_output_spatial_dimensions(v);
|
||||
}
|
||||
|
||||
return output;
|
||||
return xla::ConvertConvDimensionNumbers(input);
|
||||
}
|
||||
|
||||
xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
|
||||
|
@ -270,6 +270,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
|
||||
"//tensorflow/compiler/mlir/xla:attribute_exporter",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_module_importer",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_utils",
|
||||
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
|
||||
@ -255,15 +256,17 @@ Status RunGpuConvImpl(const GpuConvParams& params,
|
||||
} // anonymous namespace
|
||||
|
||||
StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
const HloCustomCallInstruction* cudnn_call) {
|
||||
const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
|
||||
GpuConvConfig config;
|
||||
|
||||
config.input_type = cudnn_call->operand(0)->shape().element_type();
|
||||
config.output_type = cudnn_call->shape().tuple_shapes(0).element_type();
|
||||
const Shape& operand0_shape = desc.operand0_shape;
|
||||
const Shape& operand1_shape = desc.operand1_shape;
|
||||
const Shape& result_shape = desc.result_shape;
|
||||
const CudnnConvBackendConfig& backend_config = desc.backend_config;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
|
||||
cudnn_call->backend_config<CudnnConvBackendConfig>());
|
||||
TF_ASSIGN_OR_RETURN(config.kind, GetCudnnConvKind(cudnn_call));
|
||||
config.input_type = operand0_shape.element_type();
|
||||
config.output_type = result_shape.element_type();
|
||||
config.kind = desc.kind;
|
||||
|
||||
// The third field is scratch size stored from conv_algorithm_picker
|
||||
// The operand is added to the shape field of the conv instruction
|
||||
@ -271,13 +274,9 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
config.algorithm = se::dnn::AlgorithmConfig(
|
||||
se::dnn::AlgorithmDesc(backend_config.algorithm(),
|
||||
backend_config.tensor_ops_enabled()),
|
||||
cudnn_call->shape().tuple_shapes(1).dimensions(0));
|
||||
desc.scratch_size);
|
||||
config.conv_result_scale = backend_config.conv_result_scale();
|
||||
|
||||
Shape operand0_shape = cudnn_call->operand(0)->shape();
|
||||
Shape operand1_shape = cudnn_call->operand(1)->shape();
|
||||
Shape result_shape = cudnn_call->shape().tuple_shapes(0);
|
||||
|
||||
switch (config.kind) {
|
||||
case CudnnConvKind::kForward:
|
||||
case CudnnConvKind::kForwardActivation:
|
||||
@ -311,9 +310,8 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
fusion.side_input_scale = backend_config.side_input_scale();
|
||||
}
|
||||
|
||||
const Window& window = cudnn_call->window();
|
||||
const ConvolutionDimensionNumbers& dnums =
|
||||
cudnn_call->convolution_dimension_numbers();
|
||||
const Window& window = desc.window;
|
||||
const ConvolutionDimensionNumbers& dnums = desc.dnums;
|
||||
|
||||
VLOG(3) << "Convolution Algorithm: "
|
||||
<< config.algorithm.algorithm()->algo_id();
|
||||
@ -330,7 +328,7 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
|
||||
|
||||
const int num_dimensions = window.dimensions_size();
|
||||
CHECK_LE(num_dimensions, 3) << cudnn_call->ToString();
|
||||
CHECK_LE(num_dimensions, 3) << inst_as_string;
|
||||
|
||||
// cuDNN does not support 1D convolutions. We therefore express 1D
|
||||
// convolutions as 2D convolutions where the first spatial dimension is 1.
|
||||
@ -344,18 +342,18 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();
|
||||
|
||||
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
|
||||
<< cudnn_call->ToString();
|
||||
<< inst_as_string;
|
||||
CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
|
||||
<< cudnn_call->ToString();
|
||||
<< inst_as_string;
|
||||
CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
|
||||
<< cudnn_call->ToString();
|
||||
<< inst_as_string;
|
||||
for (const WindowDimension& dim : window.dimensions()) {
|
||||
CHECK_EQ(dims_reversed, dim.window_reversal()) << cudnn_call->ToString();
|
||||
CHECK_EQ(dim.padding_low(), dim.padding_high()) << cudnn_call->ToString();
|
||||
CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
|
||||
CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
|
||||
CHECK_EQ(dim.base_dilation(), 1)
|
||||
<< "cudnn does not support base dilation; it "
|
||||
"must be made explicit with a kPad: "
|
||||
<< cudnn_call->ToString();
|
||||
<< inst_as_string;
|
||||
}
|
||||
|
||||
// cuDNN's convolution APIs support the BDYX layout for activations/output and
|
||||
@ -364,43 +362,43 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
FilterLayout filter_dl;
|
||||
DataLayout output_dl;
|
||||
|
||||
const Shape* input_shape = &config.input_shape;
|
||||
const Shape* filter_shape = &config.filter_shape;
|
||||
const Shape* output_shape = &config.output_shape;
|
||||
const Shape& input_shape = config.input_shape;
|
||||
const Shape& filter_shape = config.filter_shape;
|
||||
const Shape& output_shape = config.output_shape;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
|
||||
XlaConvLayoutsToStreamExecutorLayouts(
|
||||
dnums, input_shape->layout(), filter_shape->layout(),
|
||||
output_shape->layout()));
|
||||
dnums, input_shape.layout(), filter_shape.layout(),
|
||||
output_shape.layout()));
|
||||
|
||||
BatchDescriptor& input_descriptor = config.input_descriptor;
|
||||
input_descriptor = BatchDescriptor(effective_num_dimensions);
|
||||
input_descriptor.set_layout(input_dl)
|
||||
.set_feature_map_count(
|
||||
input_shape->dimensions(dnums.input_feature_dimension()))
|
||||
.set_count(input_shape->dimensions(dnums.input_batch_dimension()));
|
||||
input_shape.dimensions(dnums.input_feature_dimension()))
|
||||
.set_count(input_shape.dimensions(dnums.input_batch_dimension()));
|
||||
for (int dim = 0; dim < num_dimensions; ++dim) {
|
||||
// Note that the dimensions are reversed. The same holds below.
|
||||
input_descriptor.set_spatial_dim(
|
||||
static_cast<DimIndex>(effective_num_dimensions - dim - 1),
|
||||
input_shape->dimensions(dnums.input_spatial_dimensions(dim)));
|
||||
input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
|
||||
}
|
||||
|
||||
FilterDescriptor& filter_descriptor = config.filter_descriptor;
|
||||
filter_descriptor = FilterDescriptor(effective_num_dimensions);
|
||||
filter_descriptor.set_layout(filter_dl)
|
||||
.set_input_feature_map_count(
|
||||
filter_shape->dimensions(dnums.kernel_input_feature_dimension()))
|
||||
filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
|
||||
.set_output_feature_map_count(
|
||||
filter_shape->dimensions(dnums.kernel_output_feature_dimension()));
|
||||
filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
|
||||
for (int dim = 0; dim < num_dimensions; ++dim) {
|
||||
filter_descriptor.set_spatial_dim(
|
||||
static_cast<DimIndex>(effective_num_dimensions - dim - 1),
|
||||
filter_shape->dimensions(dnums.kernel_spatial_dimensions(dim)));
|
||||
filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
|
||||
}
|
||||
|
||||
config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
|
||||
config.conv_desc.set_group_count(cudnn_call->feature_group_count());
|
||||
config.conv_desc.set_group_count(desc.feature_group_count);
|
||||
config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
|
||||
for (int dim = 0; dim < num_dimensions; ++dim) {
|
||||
config.conv_desc
|
||||
@ -419,12 +417,12 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
output_descriptor = BatchDescriptor(effective_num_dimensions);
|
||||
output_descriptor.set_layout(output_dl)
|
||||
.set_feature_map_count(
|
||||
output_shape->dimensions(dnums.output_feature_dimension()))
|
||||
.set_count(output_shape->dimensions(dnums.output_batch_dimension()));
|
||||
output_shape.dimensions(dnums.output_feature_dimension()))
|
||||
.set_count(output_shape.dimensions(dnums.output_batch_dimension()));
|
||||
for (int dim = 0; dim < num_dimensions; ++dim) {
|
||||
output_descriptor.set_spatial_dim(
|
||||
static_cast<DimIndex>(effective_num_dimensions - dim - 1),
|
||||
output_shape->dimensions(dnums.output_spatial_dimensions(dim)));
|
||||
output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
|
||||
}
|
||||
|
||||
// Add a singleton dimension in the 1D convolution case.
|
||||
@ -439,6 +437,23 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
return config;
|
||||
}
|
||||
|
||||
StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
const HloCustomCallInstruction* cudnn_call) {
|
||||
GpuConvDescriptor descriptor;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
|
||||
TF_ASSIGN_OR_RETURN(descriptor.backend_config,
|
||||
cudnn_call->backend_config<CudnnConvBackendConfig>());
|
||||
descriptor.operand0_shape = cudnn_call->operand(0)->shape();
|
||||
descriptor.operand1_shape = cudnn_call->operand(1)->shape();
|
||||
descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
|
||||
descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
|
||||
descriptor.window = cudnn_call->window();
|
||||
descriptor.dnums = cudnn_call->convolution_dimension_numbers();
|
||||
descriptor.feature_group_count = cudnn_call->feature_group_count();
|
||||
return GetGpuConvConfig(descriptor, cudnn_call->ToString());
|
||||
}
|
||||
|
||||
StatusOr<GpuConvParams> GetGpuConvParams(
|
||||
const GpuConvConfig& config,
|
||||
absl::Span<se::DeviceMemoryBase> operand_buffers,
|
||||
|
@ -119,9 +119,31 @@ Status RunGpuConv(const GpuConvConfig& conv_config,
|
||||
se::ScratchAllocator* scratch_allocator, se::Stream* stream,
|
||||
RunConvOptions = {});
|
||||
|
||||
// Struct to describe properties of a convolution without being tied to specific
|
||||
// IR. Will be used to help build Convolution thunks from either XLA HLO or
|
||||
// LHLO GPU dialect in MLIR.
|
||||
struct GpuConvDescriptor {
|
||||
CudnnConvKind kind;
|
||||
CudnnConvBackendConfig backend_config;
|
||||
Shape operand0_shape;
|
||||
Shape operand1_shape;
|
||||
Shape result_shape;
|
||||
size_t scratch_size;
|
||||
Window window;
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
int64 feature_group_count;
|
||||
};
|
||||
|
||||
// Returns the convolution configuration given a XLA HLO instruction.
|
||||
StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
const HloCustomCallInstruction* cudnn_call);
|
||||
|
||||
// Returns the convolution configuration given a convolution descriptor `desc`
|
||||
// and a string representation of the convolution instruction `inst_as_string`
|
||||
// (for error reporting).
|
||||
StatusOr<GpuConvConfig> GetGpuConvConfig(const GpuConvDescriptor& desc,
|
||||
absl::string_view inst_as_string);
|
||||
|
||||
// Implementation details exposed for debugging and log analysis.
|
||||
StatusOr<GpuConvParams> GetGpuConvParams(
|
||||
const GpuConvConfig& conv_config,
|
||||
|
@ -50,6 +50,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/utils/name_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
@ -62,6 +63,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
|
||||
@ -1046,8 +1048,12 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
return EmitGemmThunkFromMlir(input);
|
||||
}
|
||||
|
||||
if (IsCustomCallToDnnConvolution(*custom_call)) {
|
||||
return ThunkEmitter(this).HandleCustomCall(custom_call);
|
||||
if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
|
||||
mlir::lmhlo_gpu::ConvForwardFusedOp,
|
||||
mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
|
||||
mlir::lmhlo_gpu::ConvBackwardFilterOp,
|
||||
mlir::lmhlo_gpu::ConvBackwardInputOp>(input.op)) {
|
||||
return EmitConvolutionThunkFromMlir(input);
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
@ -1060,6 +1066,118 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
custom_call->custom_call_target());
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
|
||||
using mlir::dyn_cast;
|
||||
using mlir::lmhlo_gpu::Activation;
|
||||
using mlir::lmhlo_gpu::ConvBackwardFilterOp;
|
||||
using mlir::lmhlo_gpu::ConvBackwardInputOp;
|
||||
using mlir::lmhlo_gpu::ConvForwardFusedOp;
|
||||
using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
|
||||
using mlir::lmhlo_gpu::ConvForwardOp;
|
||||
|
||||
// Last 2 operands of the convolution operation are the result and scratch.
|
||||
std::vector<BufferAllocation::Slice> operand_slices;
|
||||
int64 num_operands = input.op->getNumOperands();
|
||||
operand_slices.reserve(num_operands - 2);
|
||||
for (mlir::Value operand : input.op->getOperands().drop_back(2)) {
|
||||
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand));
|
||||
operand_slices.push_back(slice);
|
||||
}
|
||||
|
||||
mlir::Value conv_result = input.op->getOperand(num_operands - 2);
|
||||
mlir::Value scratch_result = input.op->getOperand(num_operands - 1);
|
||||
TF_ASSIGN_OR_RETURN(auto conv_result_slice,
|
||||
GetAllocationSliceForMlir(conv_result));
|
||||
TF_ASSIGN_OR_RETURN(auto scratch_slice,
|
||||
GetAllocationSliceForMlir(scratch_result));
|
||||
|
||||
auto apply_layout = [](const Shape& shape, mlir::ArrayAttr layout_attrib) {
|
||||
mlir::SmallVector<int64, 4> minor_to_major = llvm::to_vector<4>(
|
||||
llvm::map_range(layout_attrib, [](mlir::Attribute a) -> int64 {
|
||||
return static_cast<int64>(a.cast<mlir::IntegerAttr>().getInt());
|
||||
}));
|
||||
return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
|
||||
shape.dimensions(), minor_to_major);
|
||||
};
|
||||
|
||||
GpuConvDescriptor descriptor;
|
||||
|
||||
auto fill_conv_descriptor = [&](auto op) {
|
||||
descriptor.operand0_shape =
|
||||
apply_layout(TypeToShape(input.op->getOperand(0).getType()),
|
||||
op.backend_config().operand_0_layout());
|
||||
descriptor.operand1_shape =
|
||||
apply_layout(TypeToShape(input.op->getOperand(1).getType()),
|
||||
op.backend_config().operand_1_layout());
|
||||
descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
|
||||
op.backend_config().result_layout());
|
||||
descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
|
||||
mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
|
||||
mlir::DenseIntElementsAttr padding = op.padding().getValue();
|
||||
mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
|
||||
mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue();
|
||||
mlir::DenseElementsAttr window_reversal = op.window_reversal().getValue();
|
||||
for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
|
||||
WindowDimension* dim = descriptor.window.add_dimensions();
|
||||
// Window size for a convolution is the same as the kernel size.
|
||||
// Kernel size of the convolution is operand1_shape. We need to look at
|
||||
// the convolution dimension numbers kernel spatial dimensions to get
|
||||
// the window size.
|
||||
int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
|
||||
dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
|
||||
dim->set_stride(window_strides.getValue<int64>(index));
|
||||
dim->set_padding_low(padding.getValue<int64>(index));
|
||||
dim->set_padding_high(padding.getValue<int64>(index));
|
||||
dim->set_base_dilation(lhs_dilation.getValue<int64>(index));
|
||||
dim->set_window_dilation(rhs_dilation.getValue<int64>(index));
|
||||
dim->set_window_reversal(window_reversal.getValue<bool>(index));
|
||||
}
|
||||
descriptor.feature_group_count = op.feature_group_count();
|
||||
descriptor.backend_config.set_algorithm(
|
||||
op.backend_config().algorithm().getInt());
|
||||
descriptor.backend_config.set_tensor_ops_enabled(
|
||||
op.backend_config().tensor_ops_enabled().getValue());
|
||||
descriptor.backend_config.set_conv_result_scale(
|
||||
op.result_scale().convertToDouble());
|
||||
};
|
||||
|
||||
auto set_activation_mode = [&](auto op) -> Status {
|
||||
TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
|
||||
ConvertConvActivationMode(op.activation_mode()));
|
||||
descriptor.backend_config.set_activation_mode(
|
||||
static_cast<int64>(activation_mode));
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
if (auto op = dyn_cast<ConvForwardOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kForward;
|
||||
fill_conv_descriptor(op);
|
||||
} else if (auto op = dyn_cast<ConvBackwardInputOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kBackwardInput;
|
||||
fill_conv_descriptor(op);
|
||||
} else if (auto op = dyn_cast<ConvBackwardFilterOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kBackwardFilter;
|
||||
fill_conv_descriptor(op);
|
||||
} else if (auto op = dyn_cast<ConvForwardFusedOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kForwardActivation;
|
||||
fill_conv_descriptor(op);
|
||||
TF_RETURN_IF_ERROR(set_activation_mode(op));
|
||||
} else if (auto op = dyn_cast<ConvForwardFusedSideInputOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kForwardActivation;
|
||||
fill_conv_descriptor(op);
|
||||
TF_RETURN_IF_ERROR(set_activation_mode(op));
|
||||
descriptor.backend_config.set_side_input_scale(
|
||||
op.side_input_scale().convertToDouble());
|
||||
} else {
|
||||
return InternalError("Unexpected operation");
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
|
||||
AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
|
||||
input.thunk_info, std::move(config), std::move(operand_slices),
|
||||
conv_result_slice, scratch_slice));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
|
||||
auto build_gemm_config = [](auto op) {
|
||||
GpuGemmConfig config;
|
||||
|
@ -168,6 +168,7 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
Status HandleConditional(HloInstruction* conditional) override;
|
||||
Status HandleConvolution(HloInstruction* convolution) override;
|
||||
Status HandleCustomCall(HloInstruction* custom_call) override;
|
||||
Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);
|
||||
Status EmitGemmThunkFromMlir(MlirEmitterInput input);
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
|
||||
Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
|
||||
|
Loading…
Reference in New Issue
Block a user