From ece423eb03cc5fb1df17710c0f541e393039732d Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Thu, 10 Dec 2020 16:02:19 -0800 Subject: [PATCH] [XLA:GPU] Migrate convolution thunk emission to MLIR - Map MLIR LHLO Conv operations to ConvolutionThunk. - Refactor GetGpuConvConfig() to take in a conv descriptor, so that the same function can be used to emit a thunk from XLA HLO and MLIR representations. - Change XlaConvLayoutsToStreamExecutorLayouts() to check shape<->layout compatibility instead of layout equality, since some XLA layouts with unit dimensions are not preserved in XLA HLO Shape -> MLIR -> XLA HLO Shape path. - Note that window reversal is not yet representable in HLO/LHLO, so fall back to XLA HLO based thunk emission for that case. PiperOrigin-RevId: 346885084 Change-Id: Idf8ae057863a46b56feecb145ebaf4152a8b5c5d --- tensorflow/compiler/mlir/xla/BUILD | 19 +++ .../compiler/mlir/xla/attribute_exporter.cc | 87 +++++++++++++ .../compiler/mlir/xla/attribute_exporter.h | 37 ++++++ .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 32 +---- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/gpu_conv_runner.cc | 86 +++++++------ .../xla/service/gpu/gpu_conv_runner.h | 22 ++++ .../xla/service/gpu/ir_emitter_unnested.cc | 114 +++++++++++++++++- .../xla/service/gpu/ir_emitter_unnested.h | 1 + .../xla/service/gpu/stream_executor_util.cc | 29 ++--- .../xla/service/gpu/stream_executor_util.h | 4 +- .../compiler/xla/service/llvm_ir/ir_array.cc | 22 ---- .../compiler/xla/service/llvm_ir/ir_array.h | 9 +- tensorflow/compiler/xla/shape_util.cc | 10 ++ tensorflow/compiler/xla/shape_util.h | 9 ++ 15 files changed, 372 insertions(+), 110 deletions(-) create mode 100644 tensorflow/compiler/mlir/xla/attribute_exporter.cc create mode 100644 tensorflow/compiler/mlir/xla/attribute_exporter.h diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index fb611ab453d..63b8765c69c 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -249,6 +249,7 @@ cc_library( ], hdrs = ["mlir_hlo_to_hlo.h"], deps = [ + ":attribute_exporter", ":type_to_shape", "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir/hlo", @@ -337,6 +338,24 @@ cc_library( ], ) +cc_library( + name = "attribute_exporter", + srcs = ["attribute_exporter.cc"], + hdrs = ["attribute_exporter.h"], + deps = [ + "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:lhlo_gpu", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core/platform:types", + "//tensorflow/stream_executor:dnn", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "translate_cl_options", srcs = ["xla_mlir_translate_cl.cc"], diff --git a/tensorflow/compiler/mlir/xla/attribute_exporter.cc b/tensorflow/compiler/mlir/xla/attribute_exporter.cc new file mode 100644 index 00000000000..88296abcd81 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/attribute_exporter.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/attribute_exporter.h" + +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/stream_executor/dnn.h" + +namespace xla { + +ConvolutionDimensionNumbers ConvertConvDimensionNumbers( + mlir::mhlo::ConvDimensionNumbers input) { + ConvolutionDimensionNumbers output; + + output.set_input_batch_dimension( + input.input_batch_dimension().getValue().getSExtValue()); + output.set_input_feature_dimension( + input.input_feature_dimension().getValue().getSExtValue()); + + for (auto v : input.input_spatial_dimensions().getValues()) { + output.add_input_spatial_dimensions(v); + } + + output.set_kernel_input_feature_dimension( + input.kernel_input_feature_dimension().getValue().getSExtValue()); + output.set_kernel_output_feature_dimension( + input.kernel_output_feature_dimension().getValue().getSExtValue()); + + for (auto v : input.kernel_spatial_dimensions().getValues()) { + output.add_kernel_spatial_dimensions(v); + } + + output.set_output_batch_dimension( + input.output_batch_dimension().getValue().getSExtValue()); + output.set_output_feature_dimension( + input.output_feature_dimension().getValue().getSExtValue()); + + for (auto v : input.output_spatial_dimensions().getValues()) { + output.add_output_spatial_dimensions(v); + } + + return output; +} + +StatusOr ConvertConvActivationMode( + llvm::StringRef input) { + llvm::Optional activation = + mlir::lmhlo_gpu::symbolizeActivation(input); + if (!activation) { + return InternalError("Unexpected activation"); + } + + switch (activation.getValue()) { + case mlir::lmhlo_gpu::Activation::None: + return stream_executor::dnn::kNone; + case mlir::lmhlo_gpu::Activation::Sigmoid: + return stream_executor::dnn::kSigmoid; + case mlir::lmhlo_gpu::Activation::Tanh: + return stream_executor::dnn::kTanh; + case mlir::lmhlo_gpu::Activation::Relu: + return stream_executor::dnn::kRelu; + case mlir::lmhlo_gpu::Activation::Relu6: + return stream_executor::dnn::kRelu6; + case mlir::lmhlo_gpu::Activation::ReluX: + return stream_executor::dnn::kReluX; + case mlir::lmhlo_gpu::Activation::BandPass: + return stream_executor::dnn::kBandPass; + default: + return InternalError("Unexpected activation"); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/mlir/xla/attribute_exporter.h b/tensorflow/compiler/mlir/xla/attribute_exporter.h new file mode 100644 index 00000000000..c58cff0861c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/attribute_exporter.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/dnn.h" + +namespace xla { + +// Converts the conv dimensions attribute to XLA HLO. +ConvolutionDimensionNumbers ConvertConvDimensionNumbers( + mlir::mhlo::ConvDimensionNumbers input); + +StatusOr ConvertConvActivationMode( + llvm::StringRef input); + +} // namespace xla +#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_ diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 244da2dde1c..2495c341994 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/utils/name_utils.h" +#include "tensorflow/compiler/mlir/xla/attribute_exporter.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" @@ -297,36 +298,7 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers( static xla::ConvolutionDimensionNumbers Convert_dimension_numbers( mlir::mhlo::ConvDimensionNumbers input) { - xla::ConvolutionDimensionNumbers output; - - output.set_input_batch_dimension( - input.input_batch_dimension().getValue().getSExtValue()); - output.set_input_feature_dimension( - input.input_feature_dimension().getValue().getSExtValue()); - - for (int64 v : input.input_spatial_dimensions().getValues()) { - output.add_input_spatial_dimensions(v); - } - - output.set_kernel_input_feature_dimension( - input.kernel_input_feature_dimension().getValue().getSExtValue()); - output.set_kernel_output_feature_dimension( - input.kernel_output_feature_dimension().getValue().getSExtValue()); - - for (int64 v : input.kernel_spatial_dimensions().getValues()) { - output.add_kernel_spatial_dimensions(v); - } - - output.set_output_batch_dimension( - input.output_batch_dimension().getValue().getSExtValue()); - output.set_output_feature_dimension( - input.output_feature_dimension().getValue().getSExtValue()); - - for (int64 v : input.output_spatial_dimensions().getValues()) { - output.add_output_spatial_dimensions(v); - } - - return output; + return xla::ConvertConvDimensionNumbers(input); } xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) { diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index f9bacdd8145..ffa08e22e71 100644 --- 
a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -270,6 +270,7 @@ cc_library( "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/hlo:lhlo_gpu", + "//tensorflow/compiler/mlir/xla:attribute_exporter", "//tensorflow/compiler/mlir/xla:hlo_module_importer", "//tensorflow/compiler/mlir/xla:hlo_utils", "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index e0ccbad3a01..d70123c042a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" @@ -255,15 +256,17 @@ Status RunGpuConvImpl(const GpuConvParams& params, } // anonymous namespace StatusOr GetGpuConvConfig( - const HloCustomCallInstruction* cudnn_call) { + const GpuConvDescriptor& desc, const absl::string_view inst_as_string) { GpuConvConfig config; - config.input_type = cudnn_call->operand(0)->shape().element_type(); - config.output_type = cudnn_call->shape().tuple_shapes(0).element_type(); + const Shape& operand0_shape = desc.operand0_shape; + const Shape& operand1_shape = desc.operand1_shape; + const Shape& result_shape = desc.result_shape; + const CudnnConvBackendConfig& backend_config = desc.backend_config; - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - cudnn_call->backend_config()); - TF_ASSIGN_OR_RETURN(config.kind, GetCudnnConvKind(cudnn_call)); + config.input_type = operand0_shape.element_type(); + config.output_type = result_shape.element_type(); + config.kind = desc.kind; // The third field is scratch size stored from conv_algorithm_picker // The operand is added to the shape field of the conv instruction @@ -271,13 +274,9 @@ StatusOr GetGpuConvConfig( config.algorithm = se::dnn::AlgorithmConfig( se::dnn::AlgorithmDesc(backend_config.algorithm(), backend_config.tensor_ops_enabled()), - cudnn_call->shape().tuple_shapes(1).dimensions(0)); + desc.scratch_size); config.conv_result_scale = backend_config.conv_result_scale(); - Shape operand0_shape = cudnn_call->operand(0)->shape(); - Shape operand1_shape = cudnn_call->operand(1)->shape(); - Shape result_shape = cudnn_call->shape().tuple_shapes(0); - switch (config.kind) { case CudnnConvKind::kForward: case CudnnConvKind::kForwardActivation: @@ -311,9 +310,8 @@ StatusOr GetGpuConvConfig( fusion.side_input_scale = backend_config.side_input_scale(); } - const Window& window = cudnn_call->window(); - const ConvolutionDimensionNumbers& dnums = - cudnn_call->convolution_dimension_numbers(); + const Window& window = desc.window; + const ConvolutionDimensionNumbers& dnums = desc.dnums; VLOG(3) << "Convolution Algorithm: " << config.algorithm.algorithm()->algo_id(); @@ -330,7 +328,7 @@ StatusOr GetGpuConvConfig( VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; const int num_dimensions = window.dimensions_size(); - CHECK_LE(num_dimensions, 3) << cudnn_call->ToString(); + CHECK_LE(num_dimensions, 3) << inst_as_string; // cuDNN does not support 1D convolutions. 
We therefore express 1D // convolutions as 2D convolutions where the first spatial dimension is 1. @@ -344,18 +342,18 @@ StatusOr GetGpuConvConfig( window.dimensions_size() > 0 && window.dimensions()[0].window_reversal(); CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()) - << cudnn_call->ToString(); + << inst_as_string; CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()) - << cudnn_call->ToString(); + << inst_as_string; CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()) - << cudnn_call->ToString(); + << inst_as_string; for (const WindowDimension& dim : window.dimensions()) { - CHECK_EQ(dims_reversed, dim.window_reversal()) << cudnn_call->ToString(); - CHECK_EQ(dim.padding_low(), dim.padding_high()) << cudnn_call->ToString(); + CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string; + CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string; CHECK_EQ(dim.base_dilation(), 1) << "cudnn does not support base dilation; it " "must be made explicit with a kPad: " - << cudnn_call->ToString(); + << inst_as_string; } // cuDNN's convolution APIs support the BDYX layout for activations/output and @@ -364,43 +362,42 @@ StatusOr GetGpuConvConfig( FilterLayout filter_dl; DataLayout output_dl; - const Shape* input_shape = &config.input_shape; - const Shape* filter_shape = &config.filter_shape; - const Shape* output_shape = &config.output_shape; + const Shape& input_shape = config.input_shape; + const Shape& filter_shape = config.filter_shape; + const Shape& output_shape = config.output_shape; TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), XlaConvLayoutsToStreamExecutorLayouts( - dnums, input_shape->layout(), filter_shape->layout(), - output_shape->layout())); + dnums, input_shape, filter_shape, output_shape)); BatchDescriptor& input_descriptor = config.input_descriptor; input_descriptor = BatchDescriptor(effective_num_dimensions); input_descriptor.set_layout(input_dl) .set_feature_map_count( - input_shape->dimensions(dnums.input_feature_dimension())) - .set_count(input_shape->dimensions(dnums.input_batch_dimension())); + input_shape.dimensions(dnums.input_feature_dimension())) + .set_count(input_shape.dimensions(dnums.input_batch_dimension())); for (int dim = 0; dim < num_dimensions; ++dim) { // Note that the dimensions are reversed. The same holds below. 
input_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), - input_shape->dimensions(dnums.input_spatial_dimensions(dim))); + input_shape.dimensions(dnums.input_spatial_dimensions(dim))); } FilterDescriptor& filter_descriptor = config.filter_descriptor; filter_descriptor = FilterDescriptor(effective_num_dimensions); filter_descriptor.set_layout(filter_dl) .set_input_feature_map_count( - filter_shape->dimensions(dnums.kernel_input_feature_dimension())) + filter_shape.dimensions(dnums.kernel_input_feature_dimension())) .set_output_feature_map_count( - filter_shape->dimensions(dnums.kernel_output_feature_dimension())); + filter_shape.dimensions(dnums.kernel_output_feature_dimension())); for (int dim = 0; dim < num_dimensions; ++dim) { filter_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), - filter_shape->dimensions(dnums.kernel_spatial_dimensions(dim))); + filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); } config.conv_desc = ConvolutionDescriptor(effective_num_dimensions); - config.conv_desc.set_group_count(cudnn_call->feature_group_count()); + config.conv_desc.set_group_count(desc.feature_group_count); config.conv_desc.set_convolution_not_crosscorr(dims_reversed); for (int dim = 0; dim < num_dimensions; ++dim) { config.conv_desc @@ -419,12 +416,12 @@ StatusOr GetGpuConvConfig( output_descriptor = BatchDescriptor(effective_num_dimensions); output_descriptor.set_layout(output_dl) .set_feature_map_count( - output_shape->dimensions(dnums.output_feature_dimension())) - .set_count(output_shape->dimensions(dnums.output_batch_dimension())); + output_shape.dimensions(dnums.output_feature_dimension())) + .set_count(output_shape.dimensions(dnums.output_batch_dimension())); for (int dim = 0; dim < num_dimensions; ++dim) { output_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), - output_shape->dimensions(dnums.output_spatial_dimensions(dim))); + output_shape.dimensions(dnums.output_spatial_dimensions(dim))); } // Add a singleton dimension in the 1D convolution case. @@ -439,6 +436,23 @@ StatusOr GetGpuConvConfig( return config; } +StatusOr GetGpuConvConfig( + const HloCustomCallInstruction* cudnn_call) { + GpuConvDescriptor descriptor; + + TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call)); + TF_ASSIGN_OR_RETURN(descriptor.backend_config, + cudnn_call->backend_config()); + descriptor.operand0_shape = cudnn_call->operand(0)->shape(); + descriptor.operand1_shape = cudnn_call->operand(1)->shape(); + descriptor.result_shape = cudnn_call->shape().tuple_shapes(0); + descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0); + descriptor.window = cudnn_call->window(); + descriptor.dnums = cudnn_call->convolution_dimension_numbers(); + descriptor.feature_group_count = cudnn_call->feature_group_count(); + return GetGpuConvConfig(descriptor, cudnn_call->ToString()); +} + StatusOr GetGpuConvParams( const GpuConvConfig& config, absl::Span operand_buffers, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h index 5d27e6d6da7..af63dee867f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h @@ -119,9 +119,31 @@ Status RunGpuConv(const GpuConvConfig& conv_config, se::ScratchAllocator* scratch_allocator, se::Stream* stream, RunConvOptions = {}); +// Struct to describe properties of a convolution without being tied to specific +// IR. 
Will be used to help build Convolution thunks from either XLA HLO or
+// LHLO GPU dialect in MLIR.
+struct GpuConvDescriptor {
+  CudnnConvKind kind;
+  CudnnConvBackendConfig backend_config;
+  Shape operand0_shape;
+  Shape operand1_shape;
+  Shape result_shape;
+  size_t scratch_size;
+  Window window;
+  ConvolutionDimensionNumbers dnums;
+  int64 feature_group_count;
+};
+
+// Returns the convolution configuration given an XLA HLO instruction.
 StatusOr<GpuConvConfig> GetGpuConvConfig(
     const HloCustomCallInstruction* cudnn_call);
 
+// Returns the convolution configuration given a convolution descriptor `desc`
+// and a string representation of the convolution instruction `inst_as_string`
+// (for error reporting).
+StatusOr<GpuConvConfig> GetGpuConvConfig(const GpuConvDescriptor& desc,
+                                         absl::string_view inst_as_string);
+
 // Implementation details exposed for debugging and log analysis.
 StatusOr<GpuConvParams> GetGpuConvParams(
     const GpuConvConfig& conv_config,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 25d2ba74c29..eed88d54381 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -46,6 +46,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/utils/name_utils.h"
+#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@@ -58,6 +59,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
@@ -966,8 +968,19 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
     return EmitGemmThunkFromMlir(input);
   }
 
-  if (IsCustomCallToDnnConvolution(*custom_call)) {
-    return ThunkEmitter(this).HandleCustomCall(custom_call);
+  if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
+                mlir::lmhlo_gpu::ConvForwardFusedOp,
+                mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
+                mlir::lmhlo_gpu::ConvBackwardFilterOp,
+                mlir::lmhlo_gpu::ConvBackwardInputOp>(input.op)) {
+    // TODO(jurahul): Window reversal is not yet supported in HLO. Fall back to
+    // HLO based thunk for that case.
+    if (absl::c_any_of(
+            custom_call->window().dimensions(),
+            [](const WindowDimension& dim) { return dim.window_reversal(); })) {
+      return ThunkEmitter(this).HandleCustomCall(custom_call);
+    }
+    return EmitConvolutionThunkFromMlir(input);
   }
 
 #if GOOGLE_CUDA
@@ -980,6 +993,103 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
                        custom_call->custom_call_target());
 }
 
+Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
+  using mlir::dyn_cast;
+  using mlir::lmhlo_gpu::Activation;
+  using mlir::lmhlo_gpu::ConvBackwardFilterOp;
+  using mlir::lmhlo_gpu::ConvBackwardInputOp;
+  using mlir::lmhlo_gpu::ConvForwardFusedOp;
+  using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
+  using mlir::lmhlo_gpu::ConvForwardOp;
+
+  // Last 2 operands of the convolution operation are the result and scratch.
+ std::vector operand_slices; + int64 num_operands = input.op->getNumOperands(); + operand_slices.reserve(num_operands - 2); + for (mlir::Value operand : input.op->getOperands().drop_back(2)) { + TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand)); + operand_slices.push_back(slice); + } + + mlir::Value conv_result = input.op->getOperand(num_operands - 2); + mlir::Value scratch_result = input.op->getOperand(num_operands - 1); + TF_ASSIGN_OR_RETURN(auto conv_result_slice, + GetAllocationSliceForMlir(conv_result)); + TF_ASSIGN_OR_RETURN(auto scratch_slice, + GetAllocationSliceForMlir(scratch_result)); + + GpuConvDescriptor descriptor; + descriptor.operand0_shape = TypeToShape(input.op->getOperand(0).getType()); + descriptor.operand1_shape = TypeToShape(input.op->getOperand(1).getType()); + descriptor.result_shape = TypeToShape(conv_result.getType()); + descriptor.scratch_size = scratch_slice.size(); + + auto fill_conv_descriptor = [&](auto op) { + descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers()); + mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue(); + mlir::DenseIntElementsAttr padding = op.padding().getValue(); + mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue(); + mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue(); + for (auto index : llvm::seq(0, window_strides.getNumElements())) { + WindowDimension* dim = descriptor.window.add_dimensions(); + // Window size for a convolution is the same as the kernel size. + // Kernel size of the convolution is operand1_shape. We need to look at + // the convolution dimension numbers kernel spatial dimensions to get + // the window size. + int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index); + dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim)); + dim->set_stride(window_strides.getValue(index)); + dim->set_padding_low(padding.getValue(index)); + dim->set_padding_high(padding.getValue(index)); + dim->set_base_dilation(lhs_dilation.getValue(index)); + dim->set_window_dilation(rhs_dilation.getValue(index)); + } + descriptor.feature_group_count = op.feature_group_count(); + descriptor.backend_config.set_algorithm( + op.backend_config().algorithm().getInt()); + descriptor.backend_config.set_tensor_ops_enabled( + op.backend_config().tensor_ops_enabled().getValue()); + descriptor.backend_config.set_conv_result_scale( + op.result_scale().convertToDouble()); + }; + + auto set_activation_mode = [&](auto op) -> Status { + TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode, + ConvertConvActivationMode(op.activation_mode())); + descriptor.backend_config.set_activation_mode( + static_cast(activation_mode)); + return Status::OK(); + }; + + if (auto op = dyn_cast(input.op)) { + descriptor.kind = CudnnConvKind::kForward; + fill_conv_descriptor(op); + } else if (auto op = dyn_cast(input.op)) { + descriptor.kind = CudnnConvKind::kBackwardInput; + fill_conv_descriptor(op); + } else if (auto op = dyn_cast(input.op)) { + descriptor.kind = CudnnConvKind::kBackwardFilter; + fill_conv_descriptor(op); + } else if (auto op = dyn_cast(input.op)) { + descriptor.kind = CudnnConvKind::kForwardActivation; + fill_conv_descriptor(op); + TF_RETURN_IF_ERROR(set_activation_mode(op)); + } else if (auto op = dyn_cast(input.op)) { + descriptor.kind = CudnnConvKind::kForwardActivation; + fill_conv_descriptor(op); + TF_RETURN_IF_ERROR(set_activation_mode(op)); + descriptor.backend_config.set_side_input_scale( + 
op.side_input_scale().convertToDouble());
+  } else {
+    return InternalError("Unexpected operation");
+  }
+  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
+  AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
+      input.thunk_info, std::move(config), std::move(operand_slices),
+      conv_result_slice, scratch_slice));
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
   auto build_gemm_config = [](auto op) {
     GpuGemmConfig config;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index b5fc20d09d3..1cd78cc1a86 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -168,6 +168,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleConditional(HloInstruction* conditional) override;
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleCustomCall(HloInstruction* custom_call) override;
+  Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);
   Status EmitGemmThunkFromMlir(MlirEmitterInput input);
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
   Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
index 7293b1485fc..5bb2f9d65d9 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
@@ -122,8 +123,8 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
 
 StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
 XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
-                                      const Layout& input, const Layout& filter,
-                                      const Layout& output) {
+                                      const Shape& input, const Shape& filter,
+                                      const Shape& output) {
   Layout nchw_input, nchw_filter, nchw_output;
   std::tie(nchw_input, nchw_filter, nchw_output) =
       StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
@@ -139,35 +140,35 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
           .ConsumeValueOrDie();
 
   DataLayout input_layout;
-  if (LayoutUtil::Equal(input, nchw_input)) {
+  if (ShapeUtil::ShapeIsCompatibleWithLayout(input, nchw_input)) {
     input_layout = DataLayout::kBatchDepthYX;
-  } else if (LayoutUtil::Equal(input, nhwc_input)) {
+  } else if (ShapeUtil::ShapeIsCompatibleWithLayout(input, nhwc_input)) {
     input_layout = DataLayout::kBatchYXDepth;
   } else {
-    return InternalError("Invalid input layout %s for conv with dnums %s",
-                         LayoutUtil::HumanString(input),
+    return InternalError("Invalid input shape %s for conv with dnums %s",
+                         ShapeUtil::HumanStringWithLayout(input),
                          ConvolutionDimensionNumbersToString(dnums));
   }
 
   FilterLayout filter_layout;
-  if (LayoutUtil::Equal(filter, nchw_filter)) {
+  if (ShapeUtil::ShapeIsCompatibleWithLayout(filter, nchw_filter)) {
     filter_layout = FilterLayout::kOutputInputYX;
-  } else if (LayoutUtil::Equal(filter, nhwc_filter)) {
+  } else if (ShapeUtil::ShapeIsCompatibleWithLayout(filter, nhwc_filter)) {
     filter_layout = FilterLayout::kOutputYXInput;
   } else 
{
-    return InternalError("Invalid filter layout %s for conv with dnums %s",
-                         LayoutUtil::HumanString(filter),
+    return InternalError("Invalid filter shape %s for conv with dnums %s",
+                         ShapeUtil::HumanStringWithLayout(filter),
                          ConvolutionDimensionNumbersToString(dnums));
   }
 
   DataLayout output_layout;
-  if (LayoutUtil::Equal(output, nchw_output)) {
+  if (ShapeUtil::ShapeIsCompatibleWithLayout(output, nchw_output)) {
     output_layout = DataLayout::kBatchDepthYX;
-  } else if (LayoutUtil::Equal(output, nhwc_output)) {
+  } else if (ShapeUtil::ShapeIsCompatibleWithLayout(output, nhwc_output)) {
     output_layout = DataLayout::kBatchYXDepth;
   } else {
-    return InternalError("Invalid output layout %s for conv with dnums %s",
-                         LayoutUtil::HumanString(output),
+    return InternalError("Invalid output shape %s for conv with dnums %s",
+                         ShapeUtil::HumanStringWithLayout(output),
                          ConvolutionDimensionNumbersToString(dnums));
   }
 
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
index 2b58496e05c..ec6f405ae2e 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
@@ -49,8 +49,8 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
 StatusOr<
     std::tuple<se::dnn::DataLayout, se::dnn::FilterLayout, se::dnn::DataLayout>>
 XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
-                                      const Layout& input, const Layout& filter,
-                                      const Layout& output);
+                                      const Shape& input, const Shape& filter,
+                                      const Shape& output);
 
 // Generates and returns a unique lock per each provided executor.
 // Guarantees that blocks of code both holding a lock for the same provided
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index 6da4d08f182..73d430e2c54 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -527,27 +527,5 @@ IrArray IrArray::CastToShape(const Shape& new_shape,
   return new_irarray;
 }
 
-bool IrArray::Index::ShapeIsCompatible(const Shape& a, const Shape& b) {
-  // Compute strides for two sides of the comparison. Sometimes different shapes
-  // give the same strides:
-  //   [10, 20, 30, 1]{3,2,1,0} vs [10, 20, 1, 30]{3,2,1,0}
-  // which should be considered compatible.
-  const auto get_strides = [](const Shape& shape) {
-    int rank = shape.dimensions().size();
-    int64 stride = 1;
-    std::vector<int64> strides;
-    for (int i = 0; i < rank; i++) {
-      auto dim = shape.dimensions(shape.layout().minor_to_major(i));
-      if (dim != 1) {
-        stride *= dim;
-        strides.push_back(stride);
-      }
-    }
-    return strides;
-  };
-
-  return get_strides(a) == get_strides(b);
-}
-
 }  // namespace llvm_ir
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index dfc49ce3dde..1625869641a 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -110,12 +110,13 @@ class IrArray {
 
   bool LinearValidOnShape(const Shape& a) const;
 
-  static bool ShapeIsCompatible(const Shape& a, const Shape& b);
+  static bool ShapeIsCompatible(const Shape& a, const Shape& b) {
+    return ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
+           ShapeUtil::ReshapeIsBitcast(a, b);
+  }
 
   bool ShapeIsCompatible(const Shape& a) const {
-    return ShapeIsCompatible(
-        a, ShapeUtil::MakeShapeWithLayout(a.element_type(), dims_,
-                                          layout_.minor_to_major()));
+    return ShapeUtil::ShapeIsCompatibleWithLayout(a, layout_);
   }
 
   // Given that "this" is the target index of a reshape from `input_shape`
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index e84a2591707..c55dc1dc617 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1394,6 +1394,16 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified(
          check_input_unit_indices(output_shape, input_shape);
 }
 
+/* static */ bool ShapeUtil::ShapeIsCompatibleWithLayout(const Shape& shape,
+                                                         const Layout& layout) {
+  if (shape.rank() != layout.minor_to_major_size()) {
+    return false;
+  }
+  Shape new_shape = shape;
+  *new_shape.mutable_layout() = layout;
+  return ShapeUtil::ReshapeIsBitcast(shape, new_shape);
+}
+
 /* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
     const Shape& input_shape, const Shape& output_shape) {
   CHECK(input_shape.IsArray());
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index ff47ab6ea80..81bf652a1f9 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -645,6 +645,15 @@ class ShapeUtil {
   static bool ReshapeIsBitcast(const Shape& input_shape,
                                const Shape& output_shape);
 
+  // Returns true if changing the layout of the given shape from its existing
+  // one to 'layout' does not change the underlying layout of the elements
+  // in physical memory. As an example, the shape 'f16[1,1,1,8]{1,2,3,0}' is
+  // compatible with layout '{2,1,3,0}', since the shapes
+  // 'f16[1,1,1,8]{2,1,3,0}' and 'f16[1,1,1,8]{1,2,3,0}' have the same layout
+  // of elements in memory.
+  static bool ShapeIsCompatibleWithLayout(const Shape& shape,
+                                          const Layout& layout);
+
   // Find a physical layout for 'output_shape' such that
   // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
   // true (where 'output_shape_with_layout' is 'output_shape' with the found
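
For a concrete picture of the shape/layout compatibility check this patch introduces, here is a small illustrative sketch (not part of the patch; the helper name CompatibleWithLayout and the Example function are invented for illustration). It uses only pre-existing ShapeUtil/LayoutUtil APIs, mirrors how the new ShapeUtil helper above is implemented, and reproduces the f16[1,1,1,8] example from the new shape_util.h comment.

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace {

// Mirrors the helper added to shape_util.cc above: a shape is compatible with
// a candidate layout if re-laying the shape out with that layout keeps the
// same physical element order, i.e. the relayout is a bitcast.
bool CompatibleWithLayout(const xla::Shape& shape, const xla::Layout& layout) {
  if (shape.rank() != layout.minor_to_major_size()) return false;
  xla::Shape relaid = shape;
  *relaid.mutable_layout() = layout;
  return xla::ShapeUtil::ReshapeIsBitcast(shape, relaid);
}

// f16[1,1,1,8]{1,2,3,0} vs. layout {2,1,3,0}: only unit dimensions trade
// places in minor-to-major order, so the eight f16 elements keep the same
// order in memory and the shape is accepted even though the layouts are not
// equal. This is the kind of difference an XLA HLO -> MLIR -> XLA HLO round
// trip can introduce, which is why the conv layout check is relaxed from
// layout equality to this compatibility test.
void Example() {
  xla::Shape shape = xla::ShapeUtil::MakeShapeWithLayout(
      xla::F16, /*dimensions=*/{1, 1, 1, 8}, /*minor_to_major=*/{1, 2, 3, 0});
  xla::Layout candidate = xla::LayoutUtil::MakeLayout({2, 1, 3, 0});
  bool compatible = CompatibleWithLayout(shape, candidate);  // true
  (void)compatible;
}

}  // namespace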