[TF:MLIR] Use LayoutSensitiveInterface to assign optimal data format.
If no data format is explicitly forced by layout assignment pass options, use LayoutSensitiveInterface::GetOptimalLayout to select data format for all layout sensitive ops. PiperOrigin-RevId: 300796065 Change-Id: If5a18a07e3e5e6e6dfc689718289f8ebc057baf4
This commit is contained in:
parent
9d297ebabc
commit
0ad8a52c4d
@ -293,6 +293,51 @@ static LogicalResult VerifyTypesCompatibility(
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper functions to detect device capabilities from RuntimeDevices.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
|
||||
using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;
|
||||
|
||||
bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) {
|
||||
return device.type == ::tensorflow::DEVICE_GPU;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Returns true if at least one GPU device is available at runtime.
|
||||
bool CanUseGpuDevice(const RuntimeDevices &devices) {
|
||||
return llvm::any_of(devices.device_names(), IsGpuDevice);
|
||||
}
|
||||
|
||||
// Returns true if all of the GPUs available at runtime support TensorCores
|
||||
// (NVIDIA compute capability >= 7.0).
|
||||
bool CanUseTensorCores(const RuntimeDevices &devices) {
|
||||
auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) {
|
||||
auto md = devices.GetGpuDeviceMetadata(device);
|
||||
return md ? md->cc_major().getInt() >= 7 : false;
|
||||
};
|
||||
return llvm::all_of(
|
||||
llvm::make_filter_range(devices.device_names(), IsGpuDevice),
|
||||
has_tensor_cores);
|
||||
}
|
||||
|
||||
// Returns true if operation does not have explicit device placement that would
|
||||
// prevent it from running on GPU device.
|
||||
bool CanUseGpuDevice(Operation *op) {
|
||||
auto device_attr = op->getAttrOfType<StringAttr>("device");
|
||||
if (!device_attr || device_attr.getValue().empty()) return true;
|
||||
|
||||
DeviceNameUtils::ParsedName device;
|
||||
if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device))
|
||||
return false;
|
||||
|
||||
// We can't use GPU if operation explicitly placed on non-GPU device.
|
||||
return !device.has_type || device.type == ::tensorflow::DEVICE_GPU;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TF op helper functions to work with layout transformation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1002,8 +1047,56 @@ LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
|
||||
}
|
||||
|
||||
// Selects the data format ("NHWC" or "NCHW") expected to be the fastest for
// this convolution on the given runtime devices.
//
// Returns the current data format unchanged when no decision can be made:
// no GPU is usable for this op, the input is not a tensor, the element type is
// neither f16 nor f32, or the filter is not a ranked 4-D tensor.
StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) {
  // Keep current data format if no GPUs are available or if explicit placement
  // does not allow this operation to run on a GPU.
  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
    return data_format();

  // Input must be a tensor.
  auto input_ty = input().getType().dyn_cast<TensorType>();
  if (!input_ty) return data_format();

  auto element_ty = input_ty.getElementType();
  const bool is_f16 = element_ty.isF16();
  const bool is_f32 = element_ty.isF32();

  // For f16 data type on devices with Tensor Cores support, NHWC data format
  // is up to ~2x faster.
  if (is_f16 && CanUseTensorCores(devices)) return "NHWC";

  // For f32/f16 data types the decision depends on the filter size in spatial
  // dimensions; for other data types we keep the current data format.
  if (!is_f32 && !is_f16) return data_format();

  // Keep current data format if filter rank is unknown or not equal to 4.
  auto filter_ty = filter().getType().dyn_cast<RankedTensorType>();
  if (!filter_ty || filter_ty.getRank() != 4) return data_format();

  // The first two filter dimensions are the spatial ones.
  const int64_t d0 = filter_ty.getDimSize(0);
  const int64_t d1 = filter_ty.getDimSize(1);

  auto all_ones = [](ArrayAttr arr) -> bool {
    return llvm::all_of(arr, [](Attribute attr) -> bool {
      return attr.cast<IntegerAttr>().getInt() == 1;
    });
  };

  // Convolutions with a 1x1 filter and with strides and dilations all ones can
  // be computed as a GEMM in NHWC data format, and can be up to ~2x faster
  // than convolution in NCHW.
  const bool one_by_one = d0 == 1 && d1 == 1;
  const bool trivial_strides = all_ones(strides());
  const bool trivial_dilations = all_ones(dilations());

  // TODO(ezhulenev): This might lead to excessive transposes in the final IR,
  // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1.
  // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all
  // users can efficiently use NHWC data format?
  if (one_by_one && trivial_strides && trivial_dilations) return "NHWC";

  // If filter spatial dimensions are unknown or not 1x1 we prefer NCHW, because
  // it's the fastest option on NVIDIA GPUs with cuDNN library support.
  return "NCHW";
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -0,0 +1,25 @@
|
||||
// RUN: tf-opt %s -tf-layout-assignment -verify-diagnostics | FileCheck %s --dump-input=always
|
||||
|
||||
module attributes {
|
||||
tf.devices = {"/device:GPU:0" = {cc_major = 6 : i32, cc_minor = 0 : i32}}
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: func @transposeConv2D_3x3_f16
|
||||
func @transposeConv2D_3x3_f16(%input: tensor<1x28x28x64xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x28x28x64xf16> {
  // The only device declared in this module has compute capability 6.0, which
  // is below the 7.0 needed for NVIDIA Tensor Cores, so cuDNN prefers NCHW for
  // this f16 spatial convolution and the op is expected to move from NHWC to
  // NCHW with a transposed input.

  // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
  // CHECK-SAME: data_format = "NCHW"
  %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC", padding = "VALID", strides = [1, 1, 1, 1]}
       : (tensor<1x28x28x64xf16>, tensor<3x3x64x64xf16>) -> tensor<1x28x28x64xf16>

  return %0 : tensor<1x28x28x64xf16>
}
|
||||
|
||||
}
|
@ -0,0 +1,66 @@
|
||||
// RUN: tf-opt %s -tf-layout-assignment -verify-diagnostics | FileCheck %s --dump-input=always
|
||||
|
||||
module attributes {
|
||||
tf.devices = {"/device:GPU:0" = {cc_major = 7 : i32, cc_minor = 0 : i32}}
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: func @transposeConv2D_3x3_f32
|
||||
func @transposeConv2D_3x3_f32(%input: tensor<1x28x28x64xf32>, %filter: tensor<3x3x64x64xf32>) -> tensor<1x28x28x64xf32> {
  // A 3x3 (spatial) f32 convolution: cuDNN prefers NCHW, so the NHWC op is
  // expected to be rewritten to NCHW with a transposed input.

  // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
  // CHECK-SAME: data_format = "NCHW"
  %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC", padding = "VALID", strides = [1, 1, 1, 1]}
       : (tensor<1x28x28x64xf32>, tensor<3x3x64x64xf32>) -> tensor<1x28x28x64xf32>

  return %0 : tensor<1x28x28x64xf32>
}
|
||||
|
||||
// CHECK-LABEL: func @transposeConv2D_1x1_f32
|
||||
func @transposeConv2D_1x1_f32(%input: tensor<1x64x28x28xf32>, %filter: tensor<1x1x64x64xf32>) -> tensor<1x64x28x28xf32> {
  // Case 1: a 1x1 convolution with unit strides can be computed as a GEMM in
  // NHWC data format, so the NCHW op is expected to switch to NHWC with a
  // transposed input.

  // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
  // CHECK-SAME: data_format = "NHWC"
  %0 = "tf.Conv2D"(%input, %filter) {data_format = "NCHW", padding = "VALID", strides = [1, 1, 1, 1]}
       : (tensor<1x64x28x28xf32>, tensor<1x1x64x64xf32>) -> tensor<1x64x28x28xf32>

  // Case 2: with non-unit strides in the spatial dimensions the GEMM form is
  // not applicable, so the op keeps NCHW and consumes the original arguments.
  // %1 is not returned; the test only matches its printed form.

  // CHECK: "tf.Conv2D"(%arg0, %arg1)
  // CHECK-SAME: data_format = "NCHW"
  %1 = "tf.Conv2D"(%input, %filter) {data_format = "NCHW", padding = "VALID", strides = [1, 1, 2, 2]}
       : (tensor<1x64x28x28xf32>, tensor<1x1x64x64xf32>) -> tensor<1x64x14x14xf32>

  return %0 : tensor<1x64x28x28xf32>
}
|
||||
|
||||
// CHECK-LABEL: func @transposeConv2D_3x3_f16
|
||||
func @transposeConv2D_3x3_f16(%input: tensor<1x64x28x28xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x64x28x28xf16> {
  // The device declared in this module has compute capability 7.0. Using
  // Tensor Cores for f16 requires the input in NHWC data format, so the NCHW
  // op is expected to switch to NHWC with a transposed input.

  // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
  // CHECK-SAME: data_format = "NHWC"
  %0 = "tf.Conv2D"(%input, %filter) {data_format = "NCHW", padding = "VALID", strides = [1, 1, 1, 1]}
       : (tensor<1x64x28x28xf16>, tensor<3x3x64x64xf16>) -> tensor<1x64x28x28xf16>

  return %0 : tensor<1x64x28x28xf16>
}
|
||||
|
||||
}
|
@ -17,12 +17,15 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-layout-optimization"
|
||||
|
||||
@ -90,22 +93,32 @@ Permutation GetDataFormatPermutation(StringRef from_data_format,
|
||||
void LayoutAssignmentPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
|
||||
// TODO(ezhulenev): LayoutSensitiveInterface should select the optimal data
|
||||
// layout if there is no explicitly forced data format.
|
||||
if (force_data_format_.empty()) return;
|
||||
// Get runtime devices information from the closest parent module.
|
||||
RuntimeDevices devices;
|
||||
::tensorflow::GetDevicesFromOp(func.getParentOfType<ModuleOp>(), &devices);
|
||||
|
||||
// If there is no runtime device information and data format is not explicitly
|
||||
// forced, there is nothing to do.
|
||||
if (devices.NumDevices() == 0 && force_data_format_.empty()) return;
|
||||
|
||||
func.walk([&](LayoutSensitiveInterface layout_sensitive_interface) {
|
||||
// Get desired op data format.
|
||||
StringRef target_data_format = force_data_format_;
|
||||
if (target_data_format.empty()) {
|
||||
target_data_format = layout_sensitive_interface.GetOptimalLayout(devices);
|
||||
}
|
||||
|
||||
// Skip ops that already use target data format.
|
||||
auto data_format = layout_sensitive_interface.data_format();
|
||||
if (data_format == force_data_format_) return;
|
||||
if (data_format == target_data_format) return;
|
||||
|
||||
// Transpose arguments into the target data format.
|
||||
Permutation args_permutation =
|
||||
GetDataFormatPermutation(data_format, force_data_format_);
|
||||
GetDataFormatPermutation(data_format, target_data_format);
|
||||
|
||||
// Transpose results back to the original data format.
|
||||
Permutation res_permutation =
|
||||
GetDataFormatPermutation(force_data_format_, data_format);
|
||||
GetDataFormatPermutation(target_data_format, data_format);
|
||||
|
||||
if (args_permutation.empty() || res_permutation.empty()) return;
|
||||
|
||||
@ -119,7 +132,7 @@ void LayoutAssignmentPass::runOnFunction() {
|
||||
};
|
||||
|
||||
// Change operation data format.
|
||||
if (failed(layout_sensitive_interface.UpdateDataFormat(force_data_format_)))
|
||||
if (failed(layout_sensitive_interface.UpdateDataFormat(target_data_format)))
|
||||
return;
|
||||
|
||||
// Permute arguments into the target data format.
|
||||
|
Loading…
x
Reference in New Issue
Block a user