Use a MLIR generated kernel for tanh for the GPU backend (behind build flag).
Preliminary benchmark numbers on Titan V, rank 1 tensor with 262144 elements: Existing GPU kernel: 2.7887899999999983e-06 seconds (average of 200 runs) MLIR generated kernel: 7.366685000000019e-06 seconds (average of 200 runs) PiperOrigin-RevId: 315684120 Change-Id: I16de930fb9422f41ff18a2776ea6672564c67132
This commit is contained in:
parent
66579383a8
commit
fcfc8566c4
@ -18,6 +18,7 @@ load(
|
||||
"tf_opts_nortti_if_lite_protos",
|
||||
)
|
||||
load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
|
||||
load("//tensorflow/core/kernels:build_defs.bzl", "if_mlir_generated_gpu_kernels_enabled")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "if_nccl")
|
||||
@ -126,6 +127,13 @@ config_setting(
|
||||
},
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "mlir_generated_gpu_kernels_enabled",
|
||||
values = {
|
||||
"define": "tensorflow_enable_mlir_generated_gpu_kernels=1",
|
||||
},
|
||||
)
|
||||
|
||||
# Public support libraries ----------------------------------------------------
|
||||
|
||||
cc_library(
|
||||
@ -4134,8 +4142,22 @@ tf_kernel_library(
|
||||
|
||||
tf_kernel_library(
|
||||
name = "cwise_op",
|
||||
copts = if_mlir_generated_gpu_kernels_enabled(if_true = ["-DMLIR_GENERATED_GPU_KERNELS_ENABLED=1"]),
|
||||
prefix = "cwise_op",
|
||||
deps = MATH_DEPS,
|
||||
deps = MATH_DEPS + if_mlir_generated_gpu_kernels_enabled(if_true = [":mlir_generated_cwise_op"]),
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "mlir_generated_cwise_op",
|
||||
gpu_srcs = ["mlir_generated_cwise_op_gpu_tanh.cu.cc"],
|
||||
deps = if_cuda([
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//tensorflow/core/kernels/cubin_headers:tanh_kernels",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
|
7
tensorflow/core/kernels/build_defs.bzl
Normal file
7
tensorflow/core/kernels/build_defs.bzl
Normal file
@ -0,0 +1,7 @@
|
||||
"""Defines build macros for tensorflow kernels."""
|
||||
|
||||
def if_mlir_generated_gpu_kernels_enabled(if_true, if_false = []):
|
||||
return select({
|
||||
"//tensorflow/core/kernels:mlir_generated_gpu_kernels_enabled": if_true,
|
||||
"//conditions:default": if_false,
|
||||
})
|
@ -20,7 +20,11 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
#if MLIR_GENERATED_GPU_KERNELS_ENABLED
|
||||
DEFINE_UNARY(tanh, Eigen::half);
|
||||
#else
|
||||
DEFINE_UNARY3(tanh, Eigen::half, float, double);
|
||||
#endif
|
||||
DEFINE_SIMPLE_BINARY3(tanh_grad, Eigen::half, float, double);
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
@ -21,8 +21,12 @@ REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
|
||||
complex64, complex128);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if MLIR_GENERATED_GPU_KERNELS_ENABLED
|
||||
REGISTER(UnaryOp, GPU, "Tanh", functor::tanh, Eigen::half);
|
||||
#else
|
||||
REGISTER3(UnaryOp, GPU, "Tanh", functor::tanh, float, Eigen::half, double);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER2(UnaryOp, SYCL, "Tanh", functor::tanh, float, double);
|
||||
|
128
tensorflow/core/kernels/mlir_generated_cwise_op_gpu_tanh.cu.cc
Normal file
128
tensorflow/core/kernels/mlir_generated_cwise_op_gpu_tanh.cu.cc
Normal file
@ -0,0 +1,128 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/cubin_headers/tanh_f32_kernel.h"
|
||||
#include "tensorflow/core/kernels/cubin_headers/tanh_f64_kernel.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
Status CreateKernel(absl::string_view kernel_name, uint64_t num_args,
|
||||
absl::string_view ptx, absl::Span<const uint8_t> cubin_data,
|
||||
se::StreamExecutor* stream_exec,
|
||||
std::unique_ptr<se::KernelBase>& kernel_base) {
|
||||
se::MultiKernelLoaderSpec loader_spec(num_args);
|
||||
|
||||
if (!cubin_data.empty()) {
|
||||
loader_spec.AddCudaCubinInMemory(
|
||||
reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
|
||||
}
|
||||
|
||||
kernel_base.reset(new se::KernelBase(stream_exec));
|
||||
return stream_exec->GetKernel(loader_spec, kernel_base.get());
|
||||
}
|
||||
|
||||
class MlirGenerateTanhOp : public OpKernel {
|
||||
public:
|
||||
explicit MlirGenerateTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto* stream = ctx->op_device_context()->stream();
|
||||
se::KernelBase* kernel;
|
||||
{
|
||||
std::lock_guard<std::mutex> l(mu_);
|
||||
if (!kernel_) {
|
||||
OP_REQUIRES_OK(ctx, CreateKernel("tanh_kernel", 10, "", cubin_data_,
|
||||
stream->parent(), kernel_));
|
||||
}
|
||||
kernel = kernel_.get();
|
||||
}
|
||||
|
||||
const Tensor& inp = ctx->input(0);
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->forward_input_or_allocate_output({0}, 0, inp.shape(), &out));
|
||||
|
||||
if (inp.NumElements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
se::KernelArgsArray<10> args;
|
||||
|
||||
args.add_device_memory_argument(
|
||||
stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
|
||||
args.add_device_memory_argument(
|
||||
stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
|
||||
args.add_argument<int64_t>(0);
|
||||
args.add_argument<int64_t>(inp.NumElements());
|
||||
args.add_argument<int64_t>(1);
|
||||
|
||||
args.add_device_memory_argument(
|
||||
stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
|
||||
args.add_device_memory_argument(
|
||||
stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
|
||||
args.add_argument<int64_t>(0);
|
||||
args.add_argument<int64_t>(inp.NumElements());
|
||||
args.add_argument<int64_t>(1);
|
||||
|
||||
// TODO(b/158649746): Choose block size and thread dim according to the
|
||||
// number of input elements. For now, this supports at most 1024 elements.
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->parent()->Launch(stream, se::ThreadDim(inp.NumElements()),
|
||||
se::BlockDim(1), *kernel, args));
|
||||
}
|
||||
|
||||
protected:
|
||||
absl::Span<const uint8_t> cubin_data_;
|
||||
|
||||
private:
|
||||
std::unique_ptr<se::KernelBase> kernel_;
|
||||
std::mutex mu_;
|
||||
};
|
||||
|
||||
class MlirGenerateTanhF32Op : public MlirGenerateTanhOp {
|
||||
public:
|
||||
explicit MlirGenerateTanhF32Op(OpKernelConstruction* ctx)
|
||||
: MlirGenerateTanhOp(ctx) {
|
||||
cubin_data_ = kTanhF32Kernel;
|
||||
}
|
||||
};
|
||||
|
||||
class MlirGenerateTanhF64Op : public MlirGenerateTanhOp {
|
||||
public:
|
||||
explicit MlirGenerateTanhF64Op(OpKernelConstruction* ctx)
|
||||
: MlirGenerateTanhOp(ctx) {
|
||||
cubin_data_ = kTanhF64Kernel;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<float>("T"),
|
||||
MlirGenerateTanhF32Op);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<double>("T"),
|
||||
MlirGenerateTanhF64Op);
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user