[TF:MLIR] Use LayoutSensitiveInterface to assign optimal data format.

If no data format is explicitly forced by layout assignment pass options, use LayoutSensitiveInterface::GetOptimalLayout to select the data format for all layout-sensitive ops.
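In practice this means that running `tf-opt -tf-layout-assignment` with no forced format (the option backing the `force_data_format_` member seen in the pass diff below) now consults each op's GetOptimalLayout against the module's runtime device metadata; the new tests in this commit exercise exactly that path.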

PiperOrigin-RevId: 300796065
Change-Id: If5a18a07e3e5e6e6dfc689718289f8ebc057baf4
Eugene Zhulenev 2020-03-13 11:40:12 -07:00 committed by TensorFlower Gardener
parent 9d297ebabc
commit 0ad8a52c4d
4 changed files with 206 additions and 9 deletions


@@ -293,6 +293,51 @@ static LogicalResult VerifyTypesCompatibility(
   return success();
 }

+//===----------------------------------------------------------------------===//
+// Helper functions to detect device capabilities from RuntimeDevices.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
+using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;
+
+bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) {
+  return device.type == ::tensorflow::DEVICE_GPU;
+}
+
+}  // namespace
+
+// Returns true if at least one GPU device is available at runtime.
+bool CanUseGpuDevice(const RuntimeDevices &devices) {
+  return llvm::any_of(devices.device_names(), IsGpuDevice);
+}
+
+// Returns true if all of the GPUs available at runtime support TensorCores
+// (NVIDIA compute capability >= 7.0).
+bool CanUseTensorCores(const RuntimeDevices &devices) {
+  auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) {
+    auto md = devices.GetGpuDeviceMetadata(device);
+    return md ? md->cc_major().getInt() >= 7 : false;
+  };
+  return llvm::all_of(
+      llvm::make_filter_range(devices.device_names(), IsGpuDevice),
+      has_tensor_cores);
+}
+
+// Returns true if the operation does not have an explicit device placement
+// that would prevent it from running on a GPU device.
+bool CanUseGpuDevice(Operation *op) {
+  auto device_attr = op->getAttrOfType<StringAttr>("device");
+  if (!device_attr || device_attr.getValue().empty()) return true;
+
+  DeviceNameUtils::ParsedName device;
+  if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device))
+    return false;
+
+  // We can't use a GPU if the operation is explicitly placed on a non-GPU
+  // device.
+  return !device.has_type || device.type == ::tensorflow::DEVICE_GPU;
+}
+
 //===----------------------------------------------------------------------===//
 // TF op helper functions to work with layout transformation.
 //===----------------------------------------------------------------------===//
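For context, the placement check above boils down to parsing the op's "device" string attribute. A minimal standalone sketch of that logic, assuming only the real tensorflow::DeviceNameUtils API (which accepts both full and partial device names); the PlacementAllowsGpu helper and the demo main are hypothetical, not part of the commit:

#include <iostream>
#include <string>

#include "tensorflow/core/util/device_name_utils.h"

// Mirrors CanUseGpuDevice(Operation*): an op can run on GPU if its "device"
// attribute is empty, or parses to a name whose type is unset or "GPU"
// (the value of the tensorflow::DEVICE_GPU constant).
bool PlacementAllowsGpu(const std::string& device_attr) {
  if (device_attr.empty()) return true;
  tensorflow::DeviceNameUtils::ParsedName device;
  if (!tensorflow::DeviceNameUtils::ParseFullName(device_attr, &device))
    return false;
  return !device.has_type || device.type == "GPU";
}

int main() {
  std::cout << std::boolalpha;
  std::cout << PlacementAllowsGpu("") << "\n";                                   // true
  std::cout << PlacementAllowsGpu("/job:w/replica:0/task:0/device:GPU:0") << "\n";  // true
  std::cout << PlacementAllowsGpu("/device:CPU:0") << "\n";                      // false
  return 0;
}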
@@ -1002,8 +1047,56 @@ LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
 }

 StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) {
-  // TODO(ezhulenev): Implement optimal layout selection.
-  return "";
+  // Keep the current data format if no GPUs are available, or if explicit
+  // placement does not allow this operation to use a GPU.
+  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
+    return data_format();
+
+  // Input must be a tensor.
+  auto input_ty = input().getType().dyn_cast<TensorType>();
+  if (!input_ty) return data_format();
+
+  // For the f16 data type, the NHWC data format is up to ~2x faster on
+  // devices with Tensor Cores support.
+  const bool is_f16 = input_ty.getElementType().isF16();
+  if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
+
+  // For the f32/f16 data types the decision depends on the filter size in the
+  // spatial dimensions; for other data types we keep the current data format.
+  if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16())
+    return data_format();
+
+  // Keep the current data format if the filter rank is unknown or not equal
+  // to 4.
+  auto filter_ty = filter().getType().dyn_cast<RankedTensorType>();
+  if (!filter_ty || filter_ty.getRank() != 4) return data_format();
+
+  const int64_t d0 = filter_ty.getDimSize(0);
+  const int64_t d1 = filter_ty.getDimSize(1);
+
+  auto all_ones = [](ArrayAttr arr) -> bool {
+    return llvm::all_of(arr, [](Attribute attr) -> bool {
+      return attr.cast<IntegerAttr>().getInt() == 1;
+    });
+  };
+
+  // Convolutions with a 1x1 filter, and with strides and dilations all ones,
+  // can be computed as a GEMM in NHWC data format, which can be up to ~2x
+  // faster than the same convolution in NCHW.
+  const bool one_by_one = d0 == 1 && d1 == 1;
+  const bool trivial_strides = all_ones(strides());
+  const bool trivial_dilations = all_ones(dilations());
+
+  // TODO(ezhulenev): This might lead to excessive transposes in the final IR,
+  // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1.
+  // Also, FusedBatchNorm in training mode prefers the NCHW data format. Check
+  // if all users can efficiently use the NHWC data format?
+  if (one_by_one && trivial_strides && trivial_dilations) {
+    return "NHWC";
+  }
+
+  // If the filter's spatial dimensions are unknown or not 1x1, we prefer
+  // NCHW, because it is the fastest option on NVIDIA GPUs with cuDNN library
+  // support.
+  return "NCHW";
 }
//===----------------------------------------------------------------------===//
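Summarized, the selection above is a small decision tree. A standalone sketch with the MLIR types replaced by plain values (a hypothetical mirror function, not the actual implementation):

#include <cstdint>
#include <string>
#include <vector>

// Hypothetical mirror of Conv2DOp::GetOptimalLayout's decision tree.
std::string OptimalConv2DLayout(bool gpu_available, bool all_gpus_tensor_cores,
                                bool is_f16, bool is_f32,
                                const std::vector<int64_t>& filter_shape,
                                bool trivial_strides, bool trivial_dilations,
                                const std::string& current_format) {
  // No usable GPU: keep whatever format the op already has.
  if (!gpu_available) return current_format;
  // f16 on Tensor Cores: NHWC is up to ~2x faster.
  if (is_f16 && all_gpus_tensor_cores) return "NHWC";
  // Only f32/f16 participate in the filter-size heuristic.
  if (!is_f32 && !is_f16) return current_format;
  // The filter must be rank 4 with known spatial dimensions.
  if (filter_shape.size() != 4) return current_format;
  const bool one_by_one = filter_shape[0] == 1 && filter_shape[1] == 1;
  // A 1x1 filter with trivial strides/dilations runs as a GEMM in NHWC.
  if (one_by_one && trivial_strides && trivial_dilations) return "NHWC";
  // Everything else: NCHW is the fastest path through cuDNN.
  return "NCHW";
}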


@@ -0,0 +1,25 @@
// RUN: tf-opt %s -tf-layout-assignment -verify-diagnostics | FileCheck %s --dump-input=always
module attributes {
  tf.devices = {"/device:GPU:0" = {cc_major = 6 : i32, cc_minor = 0 : i32}}
} {

// CHECK-LABEL: func @transposeConv2D_3x3_f16
func @transposeConv2D_3x3_f16(%input: tensor<1x28x28x64xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x28x28x64xf16> {

  // cuDNN prefers NCHW data format for spatial convolutions in f16 before
  // compute capability 7.0 (NVIDIA Tensor Cores).

  // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
  // CHECK-SAME: data_format = "NCHW"
  %0 = "tf.Conv2D"(%input, %filter)
       {
         data_format = "NHWC",
         padding = "VALID",
         strides = [1, 1, 1, 1]
       } : (tensor<1x28x28x64xf16>, tensor<3x3x64x64xf16>)
        -> tensor<1x28x28x64xf16>

  return %0 : tensor<1x28x28x64xf16>
}

}
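Note the module attribute: with cc_major = 6 the Tensor Cores fast path in GetOptimalLayout is skipped, so even the f16 convolution falls through to the NCHW default. The next test file flips only the compute capability to 7.0 and expects the opposite decision for f16.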


@@ -0,0 +1,66 @@
// RUN: tf-opt %s -tf-layout-assignment -verify-diagnostics | FileCheck %s --dump-input=always
module attributes {
  tf.devices = {"/device:GPU:0" = {cc_major = 7 : i32, cc_minor = 0 : i32}}
} {

// CHECK-LABEL: func @transposeConv2D_3x3_f32
func @transposeConv2D_3x3_f32(%input: tensor<1x28x28x64xf32>, %filter: tensor<3x3x64x64xf32>) -> tensor<1x28x28x64xf32> {

  // cuDNN prefers NCHW data format for spatial convolutions.

  // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
  // CHECK-SAME: data_format = "NCHW"
  %0 = "tf.Conv2D"(%input, %filter)
       {
         data_format = "NHWC",
         padding = "VALID",
         strides = [1, 1, 1, 1]
       } : (tensor<1x28x28x64xf32>, tensor<3x3x64x64xf32>)
        -> tensor<1x28x28x64xf32>

  return %0 : tensor<1x28x28x64xf32>
}

// CHECK-LABEL: func @transposeConv2D_1x1_f32
func @transposeConv2D_1x1_f32(%input: tensor<1x64x28x28xf32>, %filter: tensor<1x1x64x64xf32>) -> tensor<1x64x28x28xf32> {

  // A 1x1 convolution can be computed as a GEMM in NHWC data format.

  // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
  // CHECK-SAME: data_format = "NHWC"
  %0 = "tf.Conv2D"(%input, %filter)
       {
         data_format = "NCHW",
         padding = "VALID",
         strides = [1, 1, 1, 1]
       } : (tensor<1x64x28x28xf32>, tensor<1x1x64x64xf32>)
        -> tensor<1x64x28x28xf32>

  // Striding in the spatial dimensions does not allow using GEMM.

  // CHECK: "tf.Conv2D"(%arg0, %arg1)
  // CHECK-SAME: data_format = "NCHW"
  %1 = "tf.Conv2D"(%input, %filter)
       {
         data_format = "NCHW",
         padding = "VALID",
         strides = [1, 1, 2, 2]
       } : (tensor<1x64x28x28xf32>, tensor<1x1x64x64xf32>)
        -> tensor<1x64x14x14xf32>

  return %0 : tensor<1x64x28x28xf32>
}

// CHECK-LABEL: func @transposeConv2D_3x3_f16
func @transposeConv2D_3x3_f16(%input: tensor<1x64x28x28xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x64x28x28xf16> {

  // To use Tensor Cores for the f16 data type, the input must be in NHWC
  // data format.

  // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1)
  // CHECK-SAME: data_format = "NHWC"
  %0 = "tf.Conv2D"(%input, %filter)
       {
         data_format = "NCHW",
         padding = "VALID",
         strides = [1, 1, 1, 1]
       } : (tensor<1x64x28x28xf16>, tensor<3x3x64x64xf16>)
        -> tensor<1x64x28x28xf16>

  return %0 : tensor<1x64x28x28xf16>
}

}
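The 1x1 GEMM equivalence exercised by @transposeConv2D_1x1_f32 follows from the shapes alone: with unit strides and dilations, a 1x1 convolution applies the same C_in x C_out matrix at every (n, h, w) position, so in NHWC it is a single matrix product over the flattened batch and spatial dimensions:

% 1x1 convolution as GEMM (unit strides/dilations), NHWC layout:
Y = X W, \qquad
X \in \mathbb{R}^{(N H W) \times C_{in}}, \quad
W \in \mathbb{R}^{C_{in} \times C_{out}}, \qquad
Y_{(n,h,w),\,c} = \sum_{k=1}^{C_{in}} X_{(n,h,w),\,k}\, W_{k,\,c}.

Striding breaks this equivalence (the second tf.Conv2D in the test) because the flattened rows would no longer correspond to the sampled input positions.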


@@ -17,12 +17,15 @@ limitations under the License.
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
 #include "mlir/Pass/Pass.h"  // TF:llvm-project
 #include "mlir/Pass/PassManager.h"  // TF:llvm-project
 #include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
 #include "mlir/Transforms/Passes.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"

 #define DEBUG_TYPE "tf-layout-optimization"
@@ -90,22 +93,32 @@ Permutation GetDataFormatPermutation(StringRef from_data_format,
 void LayoutAssignmentPass::runOnFunction() {
   FuncOp func = getFunction();

-  // TODO(ezhulenev): LayoutSensitiveInterface should select the optimal data
-  // layout if there is no explicitly forced data format.
-  if (force_data_format_.empty()) return;
+  // Get runtime devices information from the closest parent module.
+  RuntimeDevices devices;
+  ::tensorflow::GetDevicesFromOp(func.getParentOfType<ModuleOp>(), &devices);
+
+  // If there is no runtime device information and the data format is not
+  // explicitly forced, there is nothing to do.
+  if (devices.NumDevices() == 0 && force_data_format_.empty()) return;

   func.walk([&](LayoutSensitiveInterface layout_sensitive_interface) {
+    // Get the desired op data format.
+    StringRef target_data_format = force_data_format_;
+    if (target_data_format.empty()) {
+      target_data_format = layout_sensitive_interface.GetOptimalLayout(devices);
+    }
+
     // Skip ops that already use the target data format.
     auto data_format = layout_sensitive_interface.data_format();
-    if (data_format == force_data_format_) return;
+    if (data_format == target_data_format) return;

     // Transpose arguments into the target data format.
     Permutation args_permutation =
-        GetDataFormatPermutation(data_format, force_data_format_);
+        GetDataFormatPermutation(data_format, target_data_format);

     // Transpose results back to the original data format.
     Permutation res_permutation =
-        GetDataFormatPermutation(force_data_format_, data_format);
+        GetDataFormatPermutation(target_data_format, data_format);

     if (args_permutation.empty() || res_permutation.empty()) return;
@@ -119,7 +132,7 @@ void LayoutAssignmentPass::runOnFunction() {
     };

     // Change operation data format.
-    if (failed(layout_sensitive_interface.UpdateDataFormat(force_data_format_)))
+    if (failed(layout_sensitive_interface.UpdateDataFormat(target_data_format)))
       return;

     // Permute arguments into the target data format.
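GetDataFormatPermutation (defined earlier in this file and not shown in the diff) supplies the permutations the pass applies in both directions. A standalone sketch of the semantics the code above relies on, assuming a permutation is a vector of source indices; this helper is hypothetical, not the actual implementation:

#include <cstdint>
#include <string>
#include <vector>

// Sketch: a permutation p such that to[i] == from[p[i]]. Returns an empty
// vector for incompatible formats, matching the
// "if (args_permutation.empty() || res_permutation.empty()) return;" guard.
std::vector<int64_t> DataFormatPermutation(const std::string& from,
                                           const std::string& to) {
  std::vector<int64_t> permutation;
  for (char dim : to) {
    auto pos = from.find(dim);
    if (pos == std::string::npos) return {};  // incompatible formats
    permutation.push_back(static_cast<int64_t>(pos));
  }
  return permutation;
}

For NHWC to NCHW this yields {0, 3, 1, 2} for transposing arguments and, in the reverse direction, {0, 2, 3, 1} for transposing results back to the original data format.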