[XLA:GPU] Migrate convolution thunk emission to MLIR
- Map MLIR LHLO Conv operations to ConvolutionThunk.
- Refactor GetGpuConvConfig() to take a conv descriptor, so that the same function can be used to emit a thunk from both the XLA HLO and MLIR representations.
- Change XlaConvLayoutsToStreamExecutorLayouts() to check shape<->layout compatibility instead of layout equality, since some XLA layouts with unit dimensions are not preserved on the XLA HLO Shape -> MLIR -> XLA HLO Shape path.
- Note that window reversal is not yet representable in HLO/LHLO, so fall back to XLA HLO based thunk emission for that case.

PiperOrigin-RevId: 346885084
Change-Id: Idf8ae057863a46b56feecb145ebaf4152a8b5c5d
commit ece423eb03
parent 816bb157e9
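For orientation, the MLIR-based emission path added below can be condensed into the following sketch. This is an illustrative outline only, mirroring EmitConvolutionThunkFromMlir further down and assuming the surrounding XLA GPU headers; it is not an additional change.

  // Sketch: lowering an LHLO GPU convolution op to a ConvolutionThunk.
  // 1. Translate the op into the IR-agnostic GpuConvDescriptor.
  // 2. Build a GpuConvConfig from the descriptor (shared with the XLA HLO path).
  // 3. Wrap the config and the buffer slices in a ConvolutionThunk.
  GpuConvDescriptor descriptor;
  descriptor.kind = CudnnConvKind::kForward;  // chosen from the op's type
  descriptor.operand0_shape = TypeToShape(op->getOperand(0).getType());
  descriptor.operand1_shape = TypeToShape(op->getOperand(1).getType());
  descriptor.result_shape = TypeToShape(conv_result.getType());
  // ... window, dnums and backend_config are filled from the op's attributes ...
  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
  AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
      input.thunk_info, std::move(config), std::move(operand_slices),
      conv_result_slice, scratch_slice));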
@@ -249,6 +249,7 @@ cc_library(
     ],
     hdrs = ["mlir_hlo_to_hlo.h"],
     deps = [
+        ":attribute_exporter",
         ":type_to_shape",
         "//tensorflow/compiler/mlir:name_utils",
         "//tensorflow/compiler/mlir/hlo",
@@ -337,6 +338,24 @@ cc_library(
     ],
 )

+cc_library(
+    name = "attribute_exporter",
+    srcs = ["attribute_exporter.cc"],
+    hdrs = ["attribute_exporter.h"],
+    deps = [
+        "//tensorflow/compiler/mlir/hlo",
+        "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/core/platform:types",
+        "//tensorflow/stream_executor:dnn",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
 cc_library(
     name = "translate_cl_options",
     srcs = ["xla_mlir_translate_cl.cc"],
tensorflow/compiler/mlir/xla/attribute_exporter.cc (new file, 87 lines)
@@ -0,0 +1,87 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
+
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/stream_executor/dnn.h"
+
+namespace xla {
+
+ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
+    mlir::mhlo::ConvDimensionNumbers input) {
+  ConvolutionDimensionNumbers output;
+
+  output.set_input_batch_dimension(
+      input.input_batch_dimension().getValue().getSExtValue());
+  output.set_input_feature_dimension(
+      input.input_feature_dimension().getValue().getSExtValue());
+
+  for (auto v : input.input_spatial_dimensions().getValues<int64>()) {
+    output.add_input_spatial_dimensions(v);
+  }
+
+  output.set_kernel_input_feature_dimension(
+      input.kernel_input_feature_dimension().getValue().getSExtValue());
+  output.set_kernel_output_feature_dimension(
+      input.kernel_output_feature_dimension().getValue().getSExtValue());
+
+  for (auto v : input.kernel_spatial_dimensions().getValues<int64>()) {
+    output.add_kernel_spatial_dimensions(v);
+  }
+
+  output.set_output_batch_dimension(
+      input.output_batch_dimension().getValue().getSExtValue());
+  output.set_output_feature_dimension(
+      input.output_feature_dimension().getValue().getSExtValue());
+
+  for (auto v : input.output_spatial_dimensions().getValues<int64>()) {
+    output.add_output_spatial_dimensions(v);
+  }
+
+  return output;
+}
+
+StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
+    llvm::StringRef input) {
+  llvm::Optional<mlir::lmhlo_gpu::Activation> activation =
+      mlir::lmhlo_gpu::symbolizeActivation(input);
+  if (!activation) {
+    return InternalError("Unexpected activation");
+  }
+
+  switch (activation.getValue()) {
+    case mlir::lmhlo_gpu::Activation::None:
+      return stream_executor::dnn::kNone;
+    case mlir::lmhlo_gpu::Activation::Sigmoid:
+      return stream_executor::dnn::kSigmoid;
+    case mlir::lmhlo_gpu::Activation::Tanh:
+      return stream_executor::dnn::kTanh;
+    case mlir::lmhlo_gpu::Activation::Relu:
+      return stream_executor::dnn::kRelu;
+    case mlir::lmhlo_gpu::Activation::Relu6:
+      return stream_executor::dnn::kRelu6;
+    case mlir::lmhlo_gpu::Activation::ReluX:
+      return stream_executor::dnn::kReluX;
+    case mlir::lmhlo_gpu::Activation::BandPass:
+      return stream_executor::dnn::kBandPass;
+    default:
+      return InternalError("Unexpected activation");
+  }
+}
+
+}  // namespace xla
tensorflow/compiler/mlir/xla/attribute_exporter.h (new file, 37 lines)
@@ -0,0 +1,37 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
+#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
+
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/dnn.h"
+
+namespace xla {
+
+// Converts the conv dimensions attribute to XLA HLO.
+ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
+    mlir::mhlo::ConvDimensionNumbers input);
+
+StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
+    llvm::StringRef input);
+
+}  // namespace xla
+#endif  // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
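As an illustration only (not part of this change), the intended call pattern for these helpers when lowering an LHLO GPU convolution op looks roughly as follows; `op` is a hypothetical mlir::lmhlo_gpu::ConvForwardFusedOp, and the real call sites are in IrEmitterUnnested::EmitConvolutionThunkFromMlir further down.

  // Hypothetical call site; `op` and the surrounding TF_ASSIGN_OR_RETURN
  // context are assumptions made for the sake of the example.
  xla::ConvolutionDimensionNumbers dnums =
      xla::ConvertConvDimensionNumbers(op.dimension_numbers());
  TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode mode,
                      xla::ConvertConvActivationMode(op.activation_mode()));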
@@ -41,6 +41,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/utils/name_utils.h"
+#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -297,36 +298,7 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers(

 static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
     mlir::mhlo::ConvDimensionNumbers input) {
-  xla::ConvolutionDimensionNumbers output;
-
-  output.set_input_batch_dimension(
-      input.input_batch_dimension().getValue().getSExtValue());
-  output.set_input_feature_dimension(
-      input.input_feature_dimension().getValue().getSExtValue());
-
-  for (int64 v : input.input_spatial_dimensions().getValues<int64>()) {
-    output.add_input_spatial_dimensions(v);
-  }
-
-  output.set_kernel_input_feature_dimension(
-      input.kernel_input_feature_dimension().getValue().getSExtValue());
-  output.set_kernel_output_feature_dimension(
-      input.kernel_output_feature_dimension().getValue().getSExtValue());
-
-  for (int64 v : input.kernel_spatial_dimensions().getValues<int64>()) {
-    output.add_kernel_spatial_dimensions(v);
-  }
-
-  output.set_output_batch_dimension(
-      input.output_batch_dimension().getValue().getSExtValue());
-  output.set_output_feature_dimension(
-      input.output_feature_dimension().getValue().getSExtValue());
-
-  for (int64 v : input.output_spatial_dimensions().getValues<int64>()) {
-    output.add_output_spatial_dimensions(v);
-  }
-
-  return output;
+  return xla::ConvertConvDimensionNumbers(input);
 }

 xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
@@ -270,6 +270,7 @@ cc_library(
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
+        "//tensorflow/compiler/mlir/xla:attribute_exporter",
         "//tensorflow/compiler/mlir/xla:hlo_module_importer",
         "//tensorflow/compiler/mlir/xla:hlo_utils",
         "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"

 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
@@ -255,15 +256,17 @@ Status RunGpuConvImpl(const GpuConvParams& params,
 }  // anonymous namespace

 StatusOr<GpuConvConfig> GetGpuConvConfig(
-    const HloCustomCallInstruction* cudnn_call) {
+    const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
   GpuConvConfig config;

-  config.input_type = cudnn_call->operand(0)->shape().element_type();
-  config.output_type = cudnn_call->shape().tuple_shapes(0).element_type();
+  const Shape& operand0_shape = desc.operand0_shape;
+  const Shape& operand1_shape = desc.operand1_shape;
+  const Shape& result_shape = desc.result_shape;
+  const CudnnConvBackendConfig& backend_config = desc.backend_config;

-  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
-                      cudnn_call->backend_config<CudnnConvBackendConfig>());
-  TF_ASSIGN_OR_RETURN(config.kind, GetCudnnConvKind(cudnn_call));
+  config.input_type = operand0_shape.element_type();
+  config.output_type = result_shape.element_type();
+  config.kind = desc.kind;

   // The third field is scratch size stored from conv_algorithm_picker
   // The operand is added to the shape field of the conv instruction
@@ -271,13 +274,9 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
   config.algorithm = se::dnn::AlgorithmConfig(
       se::dnn::AlgorithmDesc(backend_config.algorithm(),
                              backend_config.tensor_ops_enabled()),
-      cudnn_call->shape().tuple_shapes(1).dimensions(0));
+      desc.scratch_size);
   config.conv_result_scale = backend_config.conv_result_scale();

-  Shape operand0_shape = cudnn_call->operand(0)->shape();
-  Shape operand1_shape = cudnn_call->operand(1)->shape();
-  Shape result_shape = cudnn_call->shape().tuple_shapes(0);
-
   switch (config.kind) {
     case CudnnConvKind::kForward:
     case CudnnConvKind::kForwardActivation:
@@ -311,9 +310,8 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
     fusion.side_input_scale = backend_config.side_input_scale();
   }

-  const Window& window = cudnn_call->window();
-  const ConvolutionDimensionNumbers& dnums =
-      cudnn_call->convolution_dimension_numbers();
+  const Window& window = desc.window;
+  const ConvolutionDimensionNumbers& dnums = desc.dnums;

   VLOG(3) << "Convolution Algorithm: "
           << config.algorithm.algorithm()->algo_id();
@@ -330,7 +328,7 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
   VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";

   const int num_dimensions = window.dimensions_size();
-  CHECK_LE(num_dimensions, 3) << cudnn_call->ToString();
+  CHECK_LE(num_dimensions, 3) << inst_as_string;

   // cuDNN does not support 1D convolutions. We therefore express 1D
   // convolutions as 2D convolutions where the first spatial dimension is 1.
@@ -344,18 +342,18 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
       window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();

   CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
-      << cudnn_call->ToString();
+      << inst_as_string;
   CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
-      << cudnn_call->ToString();
+      << inst_as_string;
   CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
-      << cudnn_call->ToString();
+      << inst_as_string;
   for (const WindowDimension& dim : window.dimensions()) {
-    CHECK_EQ(dims_reversed, dim.window_reversal()) << cudnn_call->ToString();
-    CHECK_EQ(dim.padding_low(), dim.padding_high()) << cudnn_call->ToString();
+    CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
+    CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
     CHECK_EQ(dim.base_dilation(), 1)
         << "cudnn does not support base dilation; it "
           "must be made explicit with a kPad: "
-        << cudnn_call->ToString();
+        << inst_as_string;
   }

   // cuDNN's convolution APIs support the BDYX layout for activations/output and
@@ -364,43 +362,42 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
   FilterLayout filter_dl;
   DataLayout output_dl;

-  const Shape* input_shape = &config.input_shape;
-  const Shape* filter_shape = &config.filter_shape;
-  const Shape* output_shape = &config.output_shape;
+  const Shape& input_shape = config.input_shape;
+  const Shape& filter_shape = config.filter_shape;
+  const Shape& output_shape = config.output_shape;

   TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
                       XlaConvLayoutsToStreamExecutorLayouts(
-                          dnums, input_shape->layout(), filter_shape->layout(),
-                          output_shape->layout()));
+                          dnums, input_shape, filter_shape, output_shape));

   BatchDescriptor& input_descriptor = config.input_descriptor;
   input_descriptor = BatchDescriptor(effective_num_dimensions);
   input_descriptor.set_layout(input_dl)
       .set_feature_map_count(
-          input_shape->dimensions(dnums.input_feature_dimension()))
-      .set_count(input_shape->dimensions(dnums.input_batch_dimension()));
+          input_shape.dimensions(dnums.input_feature_dimension()))
+      .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
   for (int dim = 0; dim < num_dimensions; ++dim) {
     // Note that the dimensions are reversed. The same holds below.
     input_descriptor.set_spatial_dim(
         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
-        input_shape->dimensions(dnums.input_spatial_dimensions(dim)));
+        input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
   }

   FilterDescriptor& filter_descriptor = config.filter_descriptor;
   filter_descriptor = FilterDescriptor(effective_num_dimensions);
   filter_descriptor.set_layout(filter_dl)
       .set_input_feature_map_count(
-          filter_shape->dimensions(dnums.kernel_input_feature_dimension()))
+          filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
       .set_output_feature_map_count(
-          filter_shape->dimensions(dnums.kernel_output_feature_dimension()));
+          filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
   for (int dim = 0; dim < num_dimensions; ++dim) {
     filter_descriptor.set_spatial_dim(
         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
-        filter_shape->dimensions(dnums.kernel_spatial_dimensions(dim)));
+        filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
   }

   config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
-  config.conv_desc.set_group_count(cudnn_call->feature_group_count());
+  config.conv_desc.set_group_count(desc.feature_group_count);
   config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
   for (int dim = 0; dim < num_dimensions; ++dim) {
     config.conv_desc
@@ -419,12 +416,12 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
   output_descriptor = BatchDescriptor(effective_num_dimensions);
   output_descriptor.set_layout(output_dl)
       .set_feature_map_count(
-          output_shape->dimensions(dnums.output_feature_dimension()))
-      .set_count(output_shape->dimensions(dnums.output_batch_dimension()));
+          output_shape.dimensions(dnums.output_feature_dimension()))
+      .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
   for (int dim = 0; dim < num_dimensions; ++dim) {
     output_descriptor.set_spatial_dim(
         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
-        output_shape->dimensions(dnums.output_spatial_dimensions(dim)));
+        output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
   }

   // Add a singleton dimension in the 1D convolution case.
@@ -439,6 +436,23 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
   return config;
 }

+StatusOr<GpuConvConfig> GetGpuConvConfig(
+    const HloCustomCallInstruction* cudnn_call) {
+  GpuConvDescriptor descriptor;
+
+  TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
+  TF_ASSIGN_OR_RETURN(descriptor.backend_config,
+                      cudnn_call->backend_config<CudnnConvBackendConfig>());
+  descriptor.operand0_shape = cudnn_call->operand(0)->shape();
+  descriptor.operand1_shape = cudnn_call->operand(1)->shape();
+  descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
+  descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
+  descriptor.window = cudnn_call->window();
+  descriptor.dnums = cudnn_call->convolution_dimension_numbers();
+  descriptor.feature_group_count = cudnn_call->feature_group_count();
+  return GetGpuConvConfig(descriptor, cudnn_call->ToString());
+}
+
 StatusOr<GpuConvParams> GetGpuConvParams(
     const GpuConvConfig& config,
     absl::Span<se::DeviceMemoryBase> operand_buffers,
@@ -119,9 +119,31 @@ Status RunGpuConv(const GpuConvConfig& conv_config,
                   se::ScratchAllocator* scratch_allocator, se::Stream* stream,
                   RunConvOptions = {});

+// Struct to describe properties of a convolution without being tied to
+// specific IR. Will be used to help build convolution thunks from either XLA
+// HLO or the LHLO GPU dialect in MLIR.
+struct GpuConvDescriptor {
+  CudnnConvKind kind;
+  CudnnConvBackendConfig backend_config;
+  Shape operand0_shape;
+  Shape operand1_shape;
+  Shape result_shape;
+  size_t scratch_size;
+  Window window;
+  ConvolutionDimensionNumbers dnums;
+  int64 feature_group_count;
+};
+
+// Returns the convolution configuration given an XLA HLO instruction.
 StatusOr<GpuConvConfig> GetGpuConvConfig(
     const HloCustomCallInstruction* cudnn_call);

+// Returns the convolution configuration given a convolution descriptor `desc`
+// and a string representation of the convolution instruction `inst_as_string`
+// (for error reporting).
+StatusOr<GpuConvConfig> GetGpuConvConfig(const GpuConvDescriptor& desc,
+                                         absl::string_view inst_as_string);
+
 // Implementation details exposed for debugging and log analysis.
 StatusOr<GpuConvParams> GetGpuConvParams(
     const GpuConvConfig& conv_config,
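For illustration only (not part of the change, and assuming the usual XLA headers), a descriptor for a plain 2D forward convolution might be populated roughly like this; the shapes and values are hypothetical:

  GpuConvDescriptor desc;
  desc.kind = CudnnConvKind::kForward;
  desc.backend_config.set_conv_result_scale(1.0);
  // Hypothetical NCHW input, OIHW filter, NCHW output.
  desc.operand0_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {8, 32, 64, 64}, {3, 2, 1, 0});
  desc.operand1_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {64, 32, 3, 3}, {3, 2, 1, 0});
  desc.result_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {8, 64, 62, 62}, {3, 2, 1, 0});
  desc.scratch_size = 0;
  desc.feature_group_count = 1;
  // desc.window and desc.dnums must describe the matching 3x3, stride-1,
  // unpadded convolution; once filled, both overloads consume the same data:
  //   TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(desc, ""));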
@@ -46,6 +46,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/utils/name_utils.h"
+#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@@ -58,6 +59,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
@@ -966,8 +968,19 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
     return EmitGemmThunkFromMlir(input);
   }

-  if (IsCustomCallToDnnConvolution(*custom_call)) {
-    return ThunkEmitter(this).HandleCustomCall(custom_call);
+  if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
+                mlir::lmhlo_gpu::ConvForwardFusedOp,
+                mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
+                mlir::lmhlo_gpu::ConvBackwardFilterOp,
+                mlir::lmhlo_gpu::ConvBackwardInputOp>(input.op)) {
+    // TODO(jurahul): Window reversal is not yet supported in HLO. Fall back
+    // to the HLO based thunk for that case.
+    if (absl::c_any_of(
+            custom_call->window().dimensions(),
+            [](const WindowDimension& dim) { return dim.window_reversal(); })) {
+      return ThunkEmitter(this).HandleCustomCall(custom_call);
+    }
+    return EmitConvolutionThunkFromMlir(input);
   }

 #if GOOGLE_CUDA
@@ -980,6 +993,103 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
                       custom_call->custom_call_target());
 }

+Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
+  using mlir::dyn_cast;
+  using mlir::lmhlo_gpu::Activation;
+  using mlir::lmhlo_gpu::ConvBackwardFilterOp;
+  using mlir::lmhlo_gpu::ConvBackwardInputOp;
+  using mlir::lmhlo_gpu::ConvForwardFusedOp;
+  using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
+  using mlir::lmhlo_gpu::ConvForwardOp;
+
+  // Last 2 operands of the convolution operation are the result and scratch.
+  std::vector<BufferAllocation::Slice> operand_slices;
+  int64 num_operands = input.op->getNumOperands();
+  operand_slices.reserve(num_operands - 2);
+  for (mlir::Value operand : input.op->getOperands().drop_back(2)) {
+    TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand));
+    operand_slices.push_back(slice);
+  }
+
+  mlir::Value conv_result = input.op->getOperand(num_operands - 2);
+  mlir::Value scratch_result = input.op->getOperand(num_operands - 1);
+  TF_ASSIGN_OR_RETURN(auto conv_result_slice,
+                      GetAllocationSliceForMlir(conv_result));
+  TF_ASSIGN_OR_RETURN(auto scratch_slice,
+                      GetAllocationSliceForMlir(scratch_result));
+
+  GpuConvDescriptor descriptor;
+  descriptor.operand0_shape = TypeToShape(input.op->getOperand(0).getType());
+  descriptor.operand1_shape = TypeToShape(input.op->getOperand(1).getType());
+  descriptor.result_shape = TypeToShape(conv_result.getType());
+  descriptor.scratch_size = scratch_slice.size();
+
+  auto fill_conv_descriptor = [&](auto op) {
+    descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
+    mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
+    mlir::DenseIntElementsAttr padding = op.padding().getValue();
+    mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
+    mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue();
+    for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
+      WindowDimension* dim = descriptor.window.add_dimensions();
+      // Window size for a convolution is the same as the kernel size.
+      // Kernel size of the convolution is operand1_shape. We need to look at
+      // the convolution dimension numbers kernel spatial dimensions to get
+      // the window size.
+      int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
+      dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
+      dim->set_stride(window_strides.getValue<int64>(index));
+      dim->set_padding_low(padding.getValue<int64>(index));
+      dim->set_padding_high(padding.getValue<int64>(index));
+      dim->set_base_dilation(lhs_dilation.getValue<int64>(index));
+      dim->set_window_dilation(rhs_dilation.getValue<int64>(index));
+    }
+    descriptor.feature_group_count = op.feature_group_count();
+    descriptor.backend_config.set_algorithm(
+        op.backend_config().algorithm().getInt());
+    descriptor.backend_config.set_tensor_ops_enabled(
+        op.backend_config().tensor_ops_enabled().getValue());
+    descriptor.backend_config.set_conv_result_scale(
+        op.result_scale().convertToDouble());
+  };
+
+  auto set_activation_mode = [&](auto op) -> Status {
+    TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
+                        ConvertConvActivationMode(op.activation_mode()));
+    descriptor.backend_config.set_activation_mode(
+        static_cast<int64>(activation_mode));
+    return Status::OK();
+  };
+
+  if (auto op = dyn_cast<ConvForwardOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kForward;
+    fill_conv_descriptor(op);
+  } else if (auto op = dyn_cast<ConvBackwardInputOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kBackwardInput;
+    fill_conv_descriptor(op);
+  } else if (auto op = dyn_cast<ConvBackwardFilterOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kBackwardFilter;
+    fill_conv_descriptor(op);
+  } else if (auto op = dyn_cast<ConvForwardFusedOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kForwardActivation;
+    fill_conv_descriptor(op);
+    TF_RETURN_IF_ERROR(set_activation_mode(op));
+  } else if (auto op = dyn_cast<ConvForwardFusedSideInputOp>(input.op)) {
+    descriptor.kind = CudnnConvKind::kForwardActivation;
+    fill_conv_descriptor(op);
+    TF_RETURN_IF_ERROR(set_activation_mode(op));
+    descriptor.backend_config.set_side_input_scale(
+        op.side_input_scale().convertToDouble());
+  } else {
+    return InternalError("Unexpected operation");
+  }
+  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
+  AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
+      input.thunk_info, std::move(config), std::move(operand_slices),
+      conv_result_slice, scratch_slice));
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
   auto build_gemm_config = [](auto op) {
     GpuGemmConfig config;
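As a concrete illustration of the window construction inside fill_conv_descriptor above (hypothetical values, not part of the change): a 3x3 kernel with stride 1, symmetric padding of 1 and no dilation on both spatial dimensions produces a Window equivalent to:

  Window window;
  for (int spatial_dim = 0; spatial_dim < 2; ++spatial_dim) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(3);             // kernel extent in this spatial dimension
    dim->set_stride(1);
    dim->set_padding_low(1);
    dim->set_padding_high(1);
    dim->set_base_dilation(1);    // from lhs_dilation
    dim->set_window_dilation(1);  // from rhs_dilation
  }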
@@ -168,6 +168,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleConditional(HloInstruction* conditional) override;
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleCustomCall(HloInstruction* custom_call) override;
+  Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);
   Status EmitGemmThunkFromMlir(MlirEmitterInput input);
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
   Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
@@ -18,6 +18,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
@@ -122,8 +123,8 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,

 StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
 XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
-                                      const Layout& input, const Layout& filter,
-                                      const Layout& output) {
+                                      const Shape& input, const Shape& filter,
+                                      const Shape& output) {
   Layout nchw_input, nchw_filter, nchw_output;
   std::tie(nchw_input, nchw_filter, nchw_output) =
       StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
@@ -139,35 +140,35 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
           .ConsumeValueOrDie();

   DataLayout input_layout;
-  if (LayoutUtil::Equal(input, nchw_input)) {
+  if (ShapeUtil::ShapeIsComatibleWithLayout(input, nchw_input)) {
     input_layout = DataLayout::kBatchDepthYX;
-  } else if (LayoutUtil::Equal(input, nhwc_input)) {
+  } else if (ShapeUtil::ShapeIsComatibleWithLayout(input, nhwc_input)) {
     input_layout = DataLayout::kBatchYXDepth;
   } else {
-    return InternalError("Invalid input layout %s for conv with dnums %s",
-                         LayoutUtil::HumanString(input),
+    return InternalError("Invalid input shape %s for conv with dnums %s",
+                         ShapeUtil::HumanStringWithLayout(input),
                          ConvolutionDimensionNumbersToString(dnums));
   }

   FilterLayout filter_layout;
-  if (LayoutUtil::Equal(filter, nchw_filter)) {
+  if (ShapeUtil::ShapeIsComatibleWithLayout(filter, nchw_filter)) {
     filter_layout = FilterLayout::kOutputInputYX;
-  } else if (LayoutUtil::Equal(filter, nhwc_filter)) {
+  } else if (ShapeUtil::ShapeIsComatibleWithLayout(filter, nhwc_filter)) {
     filter_layout = FilterLayout::kOutputYXInput;
   } else {
-    return InternalError("Invalid filter layout %s for conv with dnums %s",
-                         LayoutUtil::HumanString(filter),
+    return InternalError("Invalid filter shape %s for conv with dnums %s",
+                         ShapeUtil::HumanStringWithLayout(filter),
                          ConvolutionDimensionNumbersToString(dnums));
   }

   DataLayout output_layout;
-  if (LayoutUtil::Equal(output, nchw_output)) {
+  if (ShapeUtil::ShapeIsComatibleWithLayout(output, nchw_output)) {
     output_layout = DataLayout::kBatchDepthYX;
-  } else if (LayoutUtil::Equal(output, nhwc_output)) {
+  } else if (ShapeUtil::ShapeIsComatibleWithLayout(output, nhwc_output)) {
     output_layout = DataLayout::kBatchYXDepth;
   } else {
-    return InternalError("Invalid output layout %s for conv with dnums %s",
-                         LayoutUtil::HumanString(output),
+    return InternalError("Invalid output shape %s for conv with dnums %s",
+                         ShapeUtil::HumanStringWithLayout(output),
                          ConvolutionDimensionNumbersToString(dnums));
   }

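To see why the relaxed check matters, an illustration only (it reuses the example documented with the new ShapeUtil helper below): a shape whose degenerate dimensions were reordered by the HLO -> MLIR -> HLO round trip fails layout equality but is still bitcast-equivalent to the expected layout.

  // Only unit dimensions move between the two layouts, so the physical
  // element order is unchanged.
  Shape shape = ShapeUtil::MakeShapeWithLayout(F16, {1, 1, 1, 8}, {1, 2, 3, 0});
  Layout other = LayoutUtil::MakeLayout({2, 1, 3, 0});
  bool equal = LayoutUtil::Equal(shape.layout(), other);                  // false
  bool compatible = ShapeUtil::ShapeIsComatibleWithLayout(shape, other);  // true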
@@ -49,8 +49,8 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
 StatusOr<
     std::tuple<se::dnn::DataLayout, se::dnn::FilterLayout, se::dnn::DataLayout>>
 XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
-                                      const Layout& input, const Layout& filter,
-                                      const Layout& output);
+                                      const Shape& input, const Shape& filter,
+                                      const Shape& output);

 // Generates and returns a unique lock per each provided executor.
 // Guarantees that blocks of code both holding a lock for the same provided
@@ -527,27 +527,5 @@ IrArray IrArray::CastToShape(const Shape& new_shape,
   return new_irarray;
 }

-bool IrArray::Index::ShapeIsCompatible(const Shape& a, const Shape& b) {
-  // Compute strides for two sides of the comparison. Sometimes different shapes
-  // give the same strides:
-  //   [10, 20, 30, 1]{3,2,1,0} vs [10, 20, 1, 30]{3,2,1,0}
-  // which should be considered compatible.
-  const auto get_strides = [](const Shape& shape) {
-    int rank = shape.dimensions().size();
-    int64 stride = 1;
-    std::vector<int64> strides;
-    for (int i = 0; i < rank; i++) {
-      auto dim = shape.dimensions(shape.layout().minor_to_major(i));
-      if (dim != 1) {
-        stride *= dim;
-        strides.push_back(stride);
-      }
-    }
-    return strides;
-  };
-
-  return get_strides(a) == get_strides(b);
-}
-
 }  // namespace llvm_ir
 }  // namespace xla
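For reference (illustration only), the deleted helper's stride rule can be reproduced standalone; it treats the two shapes mentioned in the removed comment as compatible because unit dimensions contribute no stride entry.

  #include <cstdint>
  #include <vector>

  // Cumulative products of the non-unit dimension sizes, minor-to-major.
  std::vector<int64_t> NonUnitStrides(const std::vector<int64_t>& dims_minor_to_major) {
    std::vector<int64_t> strides;
    int64_t stride = 1;
    for (int64_t dim : dims_minor_to_major) {
      if (dim != 1) {
        stride *= dim;
        strides.push_back(stride);
      }
    }
    return strides;
  }

  // NonUnitStrides({1, 30, 20, 10}) == NonUnitStrides({30, 1, 20, 10})
  //   == {30, 600, 6000}, so [10,20,30,1]{3,2,1,0} and [10,20,1,30]{3,2,1,0}
  //   were considered compatible.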
@@ -110,12 +110,13 @@ class IrArray {

     bool LinearValidOnShape(const Shape& a) const;

-    static bool ShapeIsCompatible(const Shape& a, const Shape& b);
+    static bool ShapeIsCompatible(const Shape& a, const Shape& b) {
+      return ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
+             ShapeUtil::ReshapeIsBitcast(a, b);
+    }

     bool ShapeIsCompatible(const Shape& a) const {
-      return ShapeIsCompatible(
-          a, ShapeUtil::MakeShapeWithLayout(a.element_type(), dims_,
-                                            layout_.minor_to_major()));
+      return ShapeUtil::ShapeIsComatibleWithLayout(a, layout_);
     }

     // Given that "this" is the target index of a reshape from `input_shape`
@@ -1394,6 +1394,16 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified(
         check_input_unit_indices(output_shape, input_shape);
 }

+/* static */ bool ShapeUtil::ShapeIsComatibleWithLayout(const Shape& shape,
+                                                        const Layout& layout) {
+  if (shape.rank() != layout.minor_to_major_size()) {
+    return false;
+  }
+  Shape new_shape = shape;
+  *new_shape.mutable_layout() = layout;
+  return ShapeUtil::ReshapeIsBitcast(shape, new_shape);
+}
+
 /* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
     const Shape& input_shape, const Shape& output_shape) {
   CHECK(input_shape.IsArray());
@@ -645,6 +645,15 @@ class ShapeUtil {
   static bool ReshapeIsBitcast(const Shape& input_shape,
                                const Shape& output_shape);

+  // Returns true if changing the layout of the given shape from its existing
+  // one to 'layout' does not change the underlying layout of the elements
+  // in physical memory. As an example, the shape 'f16[1,1,1,8]{1,2,3,0}' is
+  // compatible with layout '{2,1,3,0}', since the shapes
+  // 'f16[1,1,1,8]{2,1,3,0}' and 'f16[1,1,1,8]{1,2,3,0}' have the same layout
+  // of elements in memory.
+  static bool ShapeIsComatibleWithLayout(const Shape& shape,
+                                         const Layout& layout);
+
   // Find a physical layout for 'output_shape' such that
   // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
   // true (where 'output_shape_with_layout' is 'output_shape' with the found