Add a kernel generator tool.
The tool takes ops defined in the TF dialect and creates a cubin (CUDA binary).

PiperOrigin-RevId: 310517625
Change-Id: I9cfe0d69eee9bf5c6c72791d109c1da582e72c73
parent 6047d50555
commit 682d67e1fe
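For illustration, an invocation of the new tool could look like this (a sketch only; the op payload is abbreviated, and the exact TF dialect text accepted depends on the MLIR parser):

    tf_to_cubin --arch=70 --tile_sizes=16,64 --same_shape=0,1 \
        --output=tanh.cubin \
        'func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> { ... }'

The positional argument is the module text; --arch selects the compute capability (70 for sm_70), and --same_shape lists argument indices known to share a shape.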
tensorflow/compiler/mlir/tools/kernel_gen/BUILD (new file, 49 lines)
@@ -0,0 +1,49 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")

licenses(["notice"])

cc_library(
    name = "cubin_creator",
    srcs = ["cubin_creator.cc"],
    hdrs = ["cubin_creator.h"],
    copts = if_cuda(["-DGOOGLE_CUDA=1"]),
    deps = [
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:TargetNVVMIR",
        "@llvm-project//mlir:Transforms",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:lhlo",
        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
        "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",  # buildcleaner: keep
        "//tensorflow/compiler/mlir/xla:xla_unfuse_batch_norm",  # buildcleaner: keep
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/service/gpu:stream_executor_util",
        "//tensorflow/compiler/xla/service/gpu:target_constants",
        "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
        "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering",
        "//tensorflow/core:cuda_libdevice_path",
        "//tensorflow/core:lib",
    ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]),
)

tf_cc_binary(
    name = "tf_to_cubin",
    srcs = ["tf_to_cubin.cc"],
    deps = [
        ":cubin_creator",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)
tensorflow/compiler/mlir/tools/kernel_gen/build_defs.bzl (new file, 96 lines)
@@ -0,0 +1,96 @@
load("//third_party/gpus/cuda:build_defs.bzl", "cuda_gpu_select_list")

def _lookup_file(filegroup, path):
    """Extracts file at (relative) path in filegroup."""
    for file in filegroup.files.to_list():
        if file.path.endswith(path):
            return file
    return None

def _gen_kernel_image_hdr_impl(ctx):
    if not ctx.attr.gpu_archs:
        fail("No GPU architecture specified, use --config=cuda or similar.")

    name = ctx.attr.name
    tile_sizes = ctx.attr.tile_size.replace("x", ",")
    same_shape = []
    if ctx.attr.same_shape:
        same_shape.append("--same_shape=%s" % ctx.attr.same_shape)

    cubins = []
    images = []
    for arch in ctx.attr.gpu_archs:
        filename = "%s.%s.cubin" % (name, arch)
        cubin = ctx.actions.declare_file(filename)
        ctx.actions.run(
            outputs = [cubin],
            executable = ctx.executable._tool,
            arguments = same_shape + [
                "--tile_sizes=%s" % tile_sizes,
                "--arch=%s" % arch.split("_")[1],
                "--output=%s" % cubin.path,
                ctx.attr.op,
            ],
            mnemonic = "compile",
        )
        cubins.append(cubin)
        images.append("--image=profile=%s,file=%s" % (arch, cubin.path))

    # Generate fatbin file from all cubins.
    fatbin = ctx.actions.declare_file("%s.fatbin" % name)
    ctx.actions.run(
        outputs = [fatbin],
        inputs = cubins,
        executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"),
        arguments = [
            "--64",
            "--cmdline=--compile-only",
            "--link",
            "--compress-all",
            "--create=%s" % fatbin.path,
        ] + images,
        mnemonic = "fatbinary",
    )

    bin2c = _lookup_file(ctx.attr._cuda_root, "bin/bin2c")
    ctx.actions.run_shell(
        outputs = [ctx.outputs.out],
        inputs = [fatbin],
        tools = [bin2c],
        command = "%s --static --const --type=int --name=%s %s 1> %s" %
                  (bin2c.path, ctx.attr.symbol, fatbin.path, ctx.outputs.out.path),
        mnemonic = "bin2c",
    )
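
# Note: the two actions above amount to roughly the following command lines
# (file names illustrative):
#   fatbinary --64 --cmdline=--compile-only --link --compress-all \
#       --create=name.fatbin --image=profile=sm_70,file=name.sm_70.cubin
#   bin2c --static --const --type=int --name=kName name.fatbin > name.h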

_gen_kernel_image_hdr = rule(
    implementation = _gen_kernel_image_hdr_impl,
    output_to_genfiles = True,
    attrs = {
        "op": attr.string(mandatory = True),
        "tile_size": attr.string(mandatory = True),
        "same_shape": attr.string(),
        "out": attr.output(mandatory = True),
        "symbol": attr.string(mandatory = True),
        "gpu_archs": attr.string_list(mandatory = True),
        "_cuda_root": attr.label(
            default = Label("//third_party/gpus/cuda:cuda_root"),
        ),
        "_tool": attr.label(
            executable = True,
            default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
            cfg = "host",
        ),
    },
)

def gen_kernel_image_hdr(name, op, tile_size, same_shape = None):
    """Generates a C header with fatbin data from a TensorFlow op."""
    _gen_kernel_image_hdr(
        name = name,
        op = op,
        tile_size = tile_size,
        same_shape = same_shape,
        out = "include/tfrt/gpu/ops/tf/%s.h" % name,
        symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
        gpu_archs = cuda_gpu_select_list("sm_{}"),
    )
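For example, a BUILD file could use the macro like this (a sketch; the target name and op string are made up, and the op text is abbreviated):

    load("//tensorflow/compiler/mlir/tools/kernel_gen:build_defs.bzl", "gen_kernel_image_hdr")

    gen_kernel_image_hdr(
        name = "tanh_kernel",
        op = "func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> { ... }",
        tile_size = "16x64",
        same_shape = "0,1",
    )

With these attributes the macro emits include/tfrt/gpu/ops/tf/tanh_kernel.h defining a symbol kTanhKernel that holds the fatbin data.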
tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc (new file, 264 lines)
@@ -0,0 +1,264 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

//===- cubin_creator.cc -----------------------------------------*- C++ -*-===//
//
// This file implements the function to compile a TF kernel function to a cubin.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/escaping.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Target/NVVMIR.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#endif

namespace {
using tensorflow::Status;
using xla::InternalError;
using xla::StatusOr;

StatusOr<std::string> GetLibdeviceDir(
    const xla::HloModuleConfig& hlo_module_config) {
  for (const string& cuda_root : tensorflow::CandidateCudaRoots(
           hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) {
    string libdevice_dir =
        tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
    VLOG(2) << "Looking for libdevice at " << libdevice_dir;
    if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
      VLOG(2) << "Found libdevice dir " << libdevice_dir;
      return libdevice_dir;
    }
  }
  return InternalError(
      "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice");
}

struct MaterializeBroadcastsPass
    : public mlir::PassWrapper<MaterializeBroadcastsPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::ConversionTarget conversionTarget(getContext());
    mlir::OwningRewritePatternList conversionPatterns;

    // Consider the xla_hlo dialect legal for tests.
    conversionTarget.addLegalDialect<mlir::xla_hlo::XlaHloDialect>();
    // The conversion uses helpers from the Standard dialect.
    conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();

    mlir::xla_hlo::SetupMaterializeBroadcastsLegality(&getContext(),
                                                      &conversionTarget);
    mlir::xla_hlo::PopulateMaterializeBroadcastsPatterns(&getContext(),
                                                         &conversionPatterns);

    if (failed(applyPartialConversion(getFunction(), conversionTarget,
                                      conversionPatterns))) {
      return signalPassFailure();
    }
  }
};

struct UnfuseBatchNormPass
    : public mlir::PassWrapper<UnfuseBatchNormPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::OwningRewritePatternList patterns;
    mlir::xla_hlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
    mlir::applyPatternsAndFoldGreedily(getOperation(), patterns);
  }
};

Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) {
    return VLOG_IS_ON(1);
  };
  pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
                      /*shouldPrintAfterPass=*/enable_if_vlog_is_on,
                      /*printModuleScope=*/false,
                      /*printAfterOnlyOnChange=*/false, llvm::dbgs());
  pm.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(false));
  pm.addNestedPass<mlir::FuncOp>(
      absl::make_unique<MaterializeBroadcastsPass>());
  pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
  pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::xla_lhlo::createLhloCopyRemovalPass());

  if (failed(pm.run(module))) {
    return InternalError("Lowering TF to LHLO failed.");
  }
  return Status::OK();
}

struct PropagateStaticKnowledge
    : public mlir::PassWrapper<PropagateStaticKnowledge,
                               mlir::OperationPass<mlir::LLVM::LLVMFuncOp>> {
  explicit PropagateStaticKnowledge(mlir::FunctionType type,
                                    llvm::ArrayRef<unsigned> same_shape_)
      : func_type(type), same_shape(same_shape_) {}

  void runOnOperation() override {
    // We know due to the TensorFlow ABI that the offset is always 0 and that
    // the innermost stride is always 1. To make this visible to the compiler,
    // we insert constants into the code and replace usages accordingly.
    // We do not change the signature so that we keep a somewhat stable ABI
    // that is easy to understand by tools.
    mlir::LLVM::LLVMFuncOp func = getOperation();
    mlir::OpBuilder b(func.getBody());
    auto index_type = func.getArgument(3).getType();
    mlir::Value one = b.create<mlir::LLVM::ConstantOp>(
        func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1));
    mlir::Value zero = b.create<mlir::LLVM::ConstantOp>(
        func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0));
    unsigned arg_pos = 0;
    std::vector<unsigned> positions;
    for (mlir::Type arg_type : func_type.getInputs()) {
      positions.push_back(arg_pos);
      func.getArgument(arg_pos + 2).replaceAllUsesWith(zero);
      arg_pos += 3 + arg_type.cast<mlir::ShapedType>().getRank() * 2;
      func.getArgument(arg_pos - 1).replaceAllUsesWith(one);
    }

    // If we have knowledge that some arguments have the same shape, we
    // can use that here. Simply replace usages of the shape parameters within
    // the function body to a single shape parameter.
    if (!same_shape.empty()) {
      int first = same_shape.front();
      int first_offset = positions.at(first);
      mlir::ShapedType first_type =
          func_type.getInput(first).cast<mlir::ShapedType>();
      unsigned rank = first_type.getRank();
      for (int same : same_shape.drop_front(1)) {
        unsigned same_offset = positions.at(same);
        auto same_type = func_type.getInput(same).cast<mlir::ShapedType>();
        if (same_type.getRank() != rank) {
          func.emitOpError() << "same shape constraints on arguments with "
                                "non-matching shapes: #"
                             << first << " and #" << same;
          signalPassFailure();
        }

        for (int i = 0; i < 2 * rank; ++i) {
          // Replace uses for second arg data with first arg.
          auto same_arg = func.getArgument(same_offset + 3 + i);
          auto first_arg = func.getArgument(first_offset + 3 + i);
          same_arg.replaceAllUsesWith(first_arg);
        }
      }
    }
  }

  mlir::FunctionType func_type;
  llvm::ArrayRef<unsigned> same_shape;
};
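
// For reference, the pass above relies on the standard MLIR-to-LLVM lowering
// of memref arguments, where each memref expands into 3 + 2 * rank scalar
// arguments. An illustrative expansion of one rank-2 argument starting at
// position N (names are not part of the code):
//   arg N + 0: allocated pointer
//   arg N + 1: aligned pointer
//   arg N + 2: offset        <- replaced with the constant 0
//   arg N + 3: size 0
//   arg N + 4: size 1
//   arg N + 5: stride 0
//   arg N + 6: stride 1      <- innermost stride, replaced with the constant 1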

Status PropagateStaticShapeKnowledgeToKernel(
    mlir::ModuleOp module, llvm::ArrayRef<unsigned> same_shape) {
  // Grab the original signature from the single function.
  auto func = *module.getBody()->op_begin<mlir::FuncOp>();

  mlir::PassManager pm(module.getContext());
  auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) {
    return VLOG_IS_ON(1);
  };
  pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
                      /*shouldPrintAfterPass=*/enable_if_vlog_is_on,
                      /*printModuleScope=*/false,
                      /*printAfterOnlyOnChange=*/false, llvm::dbgs());
  auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>();
  kernel_pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
      absl::make_unique<PropagateStaticKnowledge>(func.getType(), same_shape));

  if (failed(pm.run(module))) {
    return InternalError("Static knowledge propagation failed.");
  }
  return Status::OK();
}
}  // namespace

StatusOr<std::vector<uint8>> tensorflow::kernel_gen::GenerateCubinForTfCode(
    llvm::StringRef tf_code, std::pair<int, int> compute_capability,
    llvm::ArrayRef<unsigned> tile_sizes, llvm::ArrayRef<unsigned> same_shape,
    llvm::ArrayRef<unsigned> unroll_factors) {
  mlir::MLIRContext context;
  context.allowUnregisteredDialects();  // TODO(b/152572127)
  mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);

  TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
  TF_RETURN_IF_ERROR(
      xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors,
                                    /*collapseParallelLoops=*/false));
  TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
  TF_RETURN_IF_ERROR(
      PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));

  mlir::OwningModuleRef kernel_module =
      xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
  auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module);
  if (!llvmModule) {
    return InternalError("Could not translate MLIR module to NVVM");
  }

  llvmModule->setModuleIdentifier("acme");
  llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout);

  xla::HloModuleConfig config;
  config.set_debug_options(xla::GetDebugOptionsFromFlags());

  TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config));
  TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx(
                                           llvmModule.get(), compute_capability,
                                           config, libdevice_dir));
  VLOG(1) << ptx;

#if GOOGLE_CUDA
  return tensorflow::se::CompileGpuAsm(
      std::get<0>(compute_capability), std::get<1>(compute_capability),
      ptx.c_str(), xla::gpu::PtxOptsFromConfig(config));
#else
  return InternalError(
      "GOOGLE_CUDA not defined. Did you specify --config=cuda ?");
#endif
}
tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h (new file, 41 lines)
@@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

//===- cubin_creator.h ------------------------------------------*- C++ -*-===//
//
// This file declares the function to compile a TF kernel function to a cubin.
//
//===----------------------------------------------------------------------===//
#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_

#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace tensorflow {
namespace kernel_gen {
xla::StatusOr<std::vector<uint8>> GenerateCubinForTfCode(
    llvm::StringRef tf_code, std::pair<int, int> compute_capability = {7, 5},
    llvm::ArrayRef<unsigned> tile_sizes = {16, 64},
    llvm::ArrayRef<unsigned> same_shape = {},
    llvm::ArrayRef<unsigned> unroll_factors = {});
}  // namespace kernel_gen
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
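A minimal calling sketch for the API above, mirroring how tf_to_cubin.cc below uses it (the tf_code variable and the {0, 1} same-shape constraint are illustrative):

    #include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"

    // tf_code holds the textual MLIR module in the TF dialect.
    xla::StatusOr<std::vector<uint8>> cubin =
        tensorflow::kernel_gen::GenerateCubinForTfCode(
            tf_code, /*compute_capability=*/{7, 0},
            /*tile_sizes=*/{16, 64}, /*same_shape=*/{0, 1},
            /*unroll_factors=*/{});
    if (cubin.ok()) {
      std::vector<uint8> cubin_data = cubin.ConsumeValueOrDie();
    }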
tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc (new file, 118 lines)
@@ -0,0 +1,118 @@
// Copyright 2020 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//===- tf_to_cubin.cc -------------------------------------------*- C++ -*-===//
//
// This file implements the entry point to compile a tf op to a cubin file.
//
//===----------------------------------------------------------------------===//
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace {
bool ParseStringList(std::string string_list, std::vector<uint32>* result) {
  result->clear();
  uint32 item;
  auto items = absl::StrSplit(string_list, ',');
  for (const auto& item_str : items) {
    if (!absl::SimpleAtoi(item_str, &item)) {
      LOG(ERROR) << "Expected token " << item_str << " to be an integer";
      return false;
    }
    result->push_back(item);
  }
  return true;
}
}  // namespace

int main(int argc, char** argv) {
  std::string output_file = "foo.bin";
  int32 architecture = 50;
  std::vector<uint32> tile_sizes;
  std::vector<uint32> unroll_factors;
  std::vector<uint32> same_shape;

  auto parse_tile_sizes = [&tile_sizes](std::string tile_sizes_str) {
    if (!ParseStringList(tile_sizes_str, &tile_sizes)) {
      return false;
    }
    // Initialize with the default.
    if (tile_sizes.empty()) {
      tile_sizes.push_back(16);
      tile_sizes.push_back(64);
    }
    return true;
  };

  auto parse_unroll_factors =
      [&unroll_factors](std::string unroll_factors_str) {
        return ParseStringList(unroll_factors_str, &unroll_factors);
      };

  auto parse_same_shape = [&same_shape](std::string same_shape_str) {
    return ParseStringList(same_shape_str, &same_shape);
  };

  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("output", &output_file, "output file"),
      tensorflow::Flag("arch", &architecture,
                       "target architecture (e.g. 50 for sm_50)"),
      tensorflow::Flag("tile_sizes", parse_tile_sizes, "16,64",
                       "tile sizes to use"),
      tensorflow::Flag("unroll_factors", parse_unroll_factors, "",
                       "factors to unroll by, separated by commas"),
      tensorflow::Flag("same_shape", parse_same_shape, "",
                       "arguments with same shape, separated by commas"),
  };
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  tensorflow::port::InitMain("usage", &argc, &argv);
  if (!parse_ok) {
    return 1;
  }

  std::pair<int32, int32> compute_capability(architecture / 10,
                                             architecture % 10);

  auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode(
      argv[1], compute_capability, tile_sizes, same_shape, unroll_factors);

  if (!cubin.ok()) {
    LOG(ERROR) << cubin.status();
    return 1;
  }

  std::vector<uint8> cubin_data = cubin.ConsumeValueOrDie();

  auto status = tensorflow::WriteStringToFile(
      tensorflow::Env::Default(), output_file,
      absl::string_view{reinterpret_cast<char*>(cubin_data.data()),
                        cubin_data.size()});

  if (!status.ok()) {
    LOG(ERROR) << status;
    return 1;
  }

  return 0;
}
tensorflow/stream_executor/gpu/BUILD
@@ -222,6 +222,7 @@ cc_library(
     hdrs = if_gpu_is_configured(["asm_compiler.h"]),
     copts = tf_copts(),
     visibility = [
+        "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__",
         "//tensorflow/compiler/xla/service/gpu:__subpackages__",
         "//tensorflow/compiler/xla/service/mlir_gpu:__subpackages__",
         "//tensorflow/core/kernels:__subpackages__",