Add MLIR generated abs kernel for the GPU backend
Extract a common base class from the Tanh op that is used for both Tanh and Abs. The abs kernel is also behind the tensorflow_enable_mlir_generated_gpu_kernels flag.

PiperOrigin-RevId: 323995229
Change-Id: I58b9e5c7dc0f428983c14bef23aa806ede6cf36c

parent d2109607f0
commit a4c706fe09
tensorflow/core/kernels/cwise_op_abs.cc
@@ -21,12 +21,15 @@ REGISTER8(UnaryOp, CPU, "Abs", functor::abs, Eigen::half, bfloat16, float,
 REGISTER2(UnaryOp, CPU, "ComplexAbs", functor::abs, complex64, complex128);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
 REGISTER4(UnaryOp, GPU, "Abs", functor::abs, Eigen::half, float, double, int64);
+#endif
 REGISTER2(UnaryOp, GPU, "ComplexAbs", functor::abs, complex64, complex128);

 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
 // registration requires all int32 inputs and outputs to be in host memory.
+#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
 REGISTER_KERNEL_BUILDER(Name("Abs")
                             .Device(DEVICE_GPU)
                             .HostMemory("x")
@@ -34,6 +37,7 @@ REGISTER_KERNEL_BUILDER(Name("Abs")
                             .TypeConstraint<int32>("T"),
                         UnaryOp<CPUDevice, functor::abs<int32>>);
+#endif
 #endif

 #if TENSORFLOW_USE_SYCL
 REGISTER3(UnaryOp, SYCL, "Abs", functor::abs, float, double, int64);
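Note: when MLIR_GENERATED_GPU_KERNELS_ENABLED is defined, the #ifndef guards above compile out the Eigen-based "Abs" GPU registrations, so the MLIR-generated registrations added in cwise_op_gpu_abs.cu.cc below can claim the op without colliding with the existing kernels.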
tensorflow/core/kernels/mlir_generated/BUILD
@@ -26,11 +26,18 @@ config_setting(

 tf_kernel_library(
     name = "cwise_op",
-    gpu_srcs = ["cwise_op_gpu_tanh.cu.cc"],
+    gpu_srcs = [
+        "cwise_op_gpu_base.cu.cc",
+        "cwise_op_gpu_base.cu.h",
+        "cwise_op_gpu_abs.cu.cc",
+        "cwise_op_gpu_tanh.cu.cc",
+    ],
     tags = ["manual"],
     deps = if_cuda([
+        ":abs_kernels",
         ":tanh_kernels",
+        "@com_google_absl//absl/strings",
         "//third_party/eigen3",
         "@com_google_absl//absl/types:span",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -57,6 +64,25 @@ tf_cuda_cc_test(
     ],
 )

+tf_cuda_cc_test(
+    name = "gpu_abs_test",
+    size = "small",
+    srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_abs_test.cc"]),
+    tags = tf_cuda_tests_tags() + ["no_rocm"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/common_runtime:device",
+        "//tensorflow/core/common_runtime:device_factory",
+        "//tensorflow/core/kernels:cwise_op",
+        "//tensorflow/core/kernels:ops_testutil",
+    ],
+)
+
 # TODO(b/160731748): Re-enable when it works again.
 # gen_kernel_library(
 #     name = "bias_add",
@@ -92,3 +118,17 @@ gen_kernel_library(
     ],
     unroll_factors = "4",
 )
+
+gen_kernel_library(
+    name = "abs",
+    same_shape = "0,1",
+    tile_size = "256",
+    types = [
+        "f16",
+        "f32",
+        "f64",
+        "i32",
+        "i64",
+    ],
+    unroll_factors = "4",
+)
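Two details of these BUILD changes: the test's srcs list is wrapped in if_mlir_generated_gpu_kernels_enabled, so gpu_abs_test only compiles the test body when the MLIR-generated kernels are enabled; and tile_size = "256" with unroll_factors = "4" in the abs rule has to stay in sync with the hard-coded GetLaunchConfiguration({256}, {4}, ...) call in cwise_op_gpu_base.cu.cc below.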
tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cu.cc (new file)
@@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/mlir_generated/abs_f16_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/abs_f32_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/abs_f64_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/abs_i32_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/abs_i64_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h"

namespace tensorflow {
namespace {
GENERATE_OP_KERNEL_BASE(Abs);
}  // namespace

REGISTER_AND_GENERATE_KERNEL(Abs, F16, Eigen::half);
REGISTER_AND_GENERATE_KERNEL(Abs, F32, float);
REGISTER_AND_GENERATE_KERNEL(Abs, F64, double);
REGISTER_AND_GENERATE_KERNEL(Abs, I32, int32);
REGISTER_AND_GENERATE_KERNEL(Abs, I64, int64);
}  // namespace tensorflow
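For reference, expanding REGISTER_AND_GENERATE_KERNEL(Abs, F32, float) with the macros from cwise_op_gpu_base.cu.h (below) yields approximately the following sketch (formatting simplified):

    namespace {
    // Subclass that binds the Abs base op to the f32 cubin blob declared in
    // abs_f32_kernel.h.
    class MlirGeneratedAbsF32Op : public MlirGeneratedAbsOp {
     public:
      explicit MlirGeneratedAbsF32Op(OpKernelConstruction* ctx)
          : MlirGeneratedAbsOp(ctx, kAbsF32Kernel) {}
    };
    }  // namespace
    REGISTER_KERNEL_BUILDER(
        Name("Abs").Device(DEVICE_GPU).TypeConstraint<float>("T"),
        MlirGeneratedAbsF32Op);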
tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.cc (new file, 129 lines)
@@ -0,0 +1,129 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {
namespace {
Status CreateKernel(absl::string_view kernel_name, uint64_t num_args,
                    absl::string_view ptx, absl::Span<const uint8_t> cubin_data,
                    se::StreamExecutor* stream_exec,
                    std::unique_ptr<se::KernelBase>& kernel_base) {
  se::MultiKernelLoaderSpec loader_spec(num_args);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
  }

  kernel_base.reset(new se::KernelBase(stream_exec));
  return stream_exec->GetKernel(loader_spec, kernel_base.get());
}
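// NB: the ptx parameter is accepted but not used yet; only the in-memory
// cubin is registered with the loader spec.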

struct LaunchConfig {
  se::BlockDim blockDim;
  se::ThreadDim threadDim;
};

LaunchConfig GetLaunchConfiguration(std::vector<uint64> tile_sizes,
                                    std::vector<uint64> unrolling_factors,
                                    std::vector<uint64> shape) {
  LaunchConfig result;
  // Ensure the vectors are length 3 and pad with ones.
  tile_sizes.resize(3, 1);
  unrolling_factors.resize(3, 1);
  shape.resize(3, 1);
  // The number of threads is given by the tiling size.
  result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]);
  // We know that the kernel was generated by mapping the three outer-most
  // dimensions to x,y,z dimensions. So we only need to compute those.
  std::vector<int> block_dims(3);
  for (int i = 0; i < 3; ++i) {
    // Compute the number of grids. We use ceildiv here as we have to allocate
    // an extra thread/block if the division is not even. The kernel contains
    // code to handle the boundaries.
    uint64 number_of_threads = Eigen::divup(shape[i], unrolling_factors[i]);
    int number_of_grids = Eigen::divup(number_of_threads, tile_sizes[i]);
    block_dims[i] = number_of_grids;
  }
  result.blockDim = se::BlockDim(block_dims[0], block_dims[1], block_dims[2]);
  return result;
}
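// Worked example for the configuration used in Compute below: with
// tile_sizes = {256}, unrolling_factors = {4} and 1,000,000 elements, each
// thread handles 4 elements, so number_of_threads = divup(1000000, 4) =
// 250000 and number_of_grids = divup(250000, 256) = 977, i.e. 977 blocks of
// 256 threads; the bounds checks in the generated kernel absorb the overshoot.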
}  // namespace

void MlirGeneratedUnaryOp::Compute(OpKernelContext* ctx) {
  auto* stream = ctx->op_device_context()->stream();
  se::KernelBase* kernel;
  {
    absl::MutexLock l(&mu_);
    if (!kernel_) {
      OP_REQUIRES_OK(ctx, CreateKernel(name_, 10, "", cubin_data_,
                                       stream->parent(), kernel_));
    }
    kernel = kernel_.get();
  }

  const Tensor& inp = ctx->input(0);
  Tensor* out = nullptr;
  OP_REQUIRES_OK(
      ctx, ctx->forward_input_or_allocate_output({0}, 0, inp.shape(), &out));

  if (inp.NumElements() == 0) {
    return;
  }

  se::KernelArgsArray<10> args;

  args.add_device_memory_argument(
      stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
  args.add_device_memory_argument(
      stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
  args.add_argument<int64_t>(0);
  args.add_argument<int64_t>(inp.NumElements());
  args.add_argument<int64_t>(1);

  args.add_device_memory_argument(
      stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
  args.add_device_memory_argument(
      stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
  args.add_argument<int64_t>(0);
  args.add_argument<int64_t>(inp.NumElements());
  args.add_argument<int64_t>(1);
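  // The five values pushed per tensor (two pointers, an offset, a size and a
  // stride) presumably mirror MLIR's memref descriptor ABI for a rank-1
  // memref, which is why CreateKernel is told to expect 10 arguments.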

  // This has to be aligned with the configuration that was used when building
  // the kernels. See the corresponding build rules in the `BUILD` file.
  LaunchConfig config = GetLaunchConfiguration(
      {256}, {4}, {static_cast<uint64>(inp.NumElements())});
  OP_REQUIRES_OK(ctx, stream->parent()->Launch(stream, config.threadDim,
                                               config.blockDim, *kernel, args));
}

}  // namespace tensorflow
tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h (new file)
@@ -0,0 +1,77 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_CU_H_
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_CU_H_

#include <memory>
#include <string>

#include "absl/strings/ascii.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {
class MlirGeneratedUnaryOp : public OpKernel {
 public:
  MlirGeneratedUnaryOp(OpKernelConstruction* ctx, std::string name,
                       absl::Span<const uint8_t> cubin_data)
      : OpKernel(ctx), name_(name), cubin_data_(cubin_data) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  std::string name_;
  absl::Span<const uint8_t> cubin_data_;
  std::unique_ptr<se::KernelBase> kernel_;
  absl::Mutex mu_;
};

#define GENERATE_OP_KERNEL_BASE(kernel_name)                                  \
  class MlirGenerated##kernel_name##Op : public MlirGeneratedUnaryOp {        \
   public:                                                                    \
    MlirGenerated##kernel_name##Op(OpKernelConstruction* ctx,                 \
                                   absl::Span<const uint8_t> cubin_data)      \
        : MlirGeneratedUnaryOp(ctx,                                           \
                               absl::AsciiStrToLower(#kernel_name "_kernel"), \
                               cubin_data) {}                                 \
  };
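// For example, GENERATE_OP_KERNEL_BASE(Abs) defines MlirGeneratedAbsOp, whose
// kernel is looked up under the lowercased symbol "abs_kernel".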

#define GENERATE_OP_KERNEL_FOR(kernel_name, data_type)     \
  class MlirGenerated##kernel_name##data_type##Op          \
      : public MlirGenerated##kernel_name##Op {            \
   public:                                                 \
    explicit MlirGenerated##kernel_name##data_type##Op(    \
        OpKernelConstruction* ctx)                         \
        : MlirGenerated##kernel_name                       \
          ##Op(ctx, k##kernel_name##data_type##Kernel) {}  \
  };

#define REGISTER_AND_GENERATE_KERNEL(kernel_name, data_type, native_data_type) \
  namespace {                                                                   \
  GENERATE_OP_KERNEL_FOR(kernel_name, data_type)                                \
  }                                                                             \
  REGISTER_KERNEL_BUILDER(Name(#kernel_name)                                    \
                              .Device(DEVICE_GPU)                               \
                              .TypeConstraint<native_data_type>("T"),           \
                          MlirGenerated##kernel_name##data_type##Op);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_CU_H_
tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cu.cc
@@ -13,164 +13,24 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include <memory>
 #include <string>
 #include <vector>

-#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h"
 #include "tensorflow/core/kernels/mlir_generated/tanh_f16_kernel.h"
 #include "tensorflow/core/kernels/mlir_generated/tanh_f32_kernel.h"
 #include "tensorflow/core/kernels/mlir_generated/tanh_f64_kernel.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor.h"

 namespace tensorflow {
 namespace {
-Status CreateKernel(absl::string_view kernel_name, uint64_t num_args,
-                    absl::string_view ptx, absl::Span<const uint8_t> cubin_data,
-                    se::StreamExecutor* stream_exec,
-                    std::unique_ptr<se::KernelBase>& kernel_base) {
-  se::MultiKernelLoaderSpec loader_spec(num_args);
-
-  if (!cubin_data.empty()) {
-    loader_spec.AddCudaCubinInMemory(
-        reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
-  }
-
-  kernel_base.reset(new se::KernelBase(stream_exec));
-  return stream_exec->GetKernel(loader_spec, kernel_base.get());
-}
-
-struct LaunchConfig {
-  se::BlockDim blockDim;
-  se::ThreadDim threadDim;
-};
-
-LaunchConfig GetLaunchConfiguration(std::vector<uint64> tile_sizes,
-                                    std::vector<uint64> unrolling_factors,
-                                    std::vector<uint64> shape) {
-  LaunchConfig result;
-  // Ensure the vectors are length 3 and pad with ones.
-  tile_sizes.resize(3, 1);
-  unrolling_factors.resize(3, 1);
-  shape.resize(3, 1);
-  // The number of threads is given by the tiling size.
-  result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]);
-  // We know that the kernel was generated by mapping the three outer-most
-  // dimensions to x,y,z dimensions. So we only need to compute those.
-  std::vector<int> block_dims(3);
-  for (int i = 0; i < 3; ++i) {
-    // Compute the number of grids. We use ceildiv here as we have to allocate
-    // an extra thread/block if the division is not even. The kernel contains
-    // code to handle the boundaries.
-    int number_of_threads =
-        (shape[i] + unrolling_factors[i] - 1) / unrolling_factors[i];
-    int number_of_grids =
-        (number_of_threads + tile_sizes[i] - 1) / tile_sizes[i];
-    block_dims[i] = number_of_grids;
-  }
-  result.blockDim = se::BlockDim(block_dims[0], block_dims[1], block_dims[2]);
-  return result;
-}
-
-class MlirGeneratedTanhOp : public OpKernel {
- public:
-  explicit MlirGeneratedTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
-  void Compute(OpKernelContext* ctx) override {
-    auto* stream = ctx->op_device_context()->stream();
-    se::KernelBase* kernel;
-    {
-      std::lock_guard<std::mutex> l(mu_);
-      if (!kernel_) {
-        OP_REQUIRES_OK(ctx, CreateKernel("Tanh_kernel", 10, "", cubin_data_,
-                                         stream->parent(), kernel_));
-      }
-      kernel = kernel_.get();
-    }
-
-    const Tensor& inp = ctx->input(0);
-    Tensor* out = nullptr;
-    OP_REQUIRES_OK(
-        ctx, ctx->forward_input_or_allocate_output({0}, 0, inp.shape(), &out));
-
-    if (inp.NumElements() == 0) {
-      return;
-    }
-
-    se::KernelArgsArray<10> args;
-
-    args.add_device_memory_argument(
-        stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
-    args.add_device_memory_argument(
-        stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
-    args.add_argument<int64_t>(0);
-    args.add_argument<int64_t>(inp.NumElements());
-    args.add_argument<int64_t>(1);
-
-    args.add_device_memory_argument(
-        stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
-    args.add_device_memory_argument(
-        stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
-    args.add_argument<int64_t>(0);
-    args.add_argument<int64_t>(inp.NumElements());
-    args.add_argument<int64_t>(1);
-
-    // This has to be aligned with the configuration that was used when
-    // generating the kernels. See the corresponding build rules in the `BUILD`
-    // file.
-    LaunchConfig config = GetLaunchConfiguration(
-        {256}, {4}, {static_cast<uint64>(inp.NumElements())});
-    OP_REQUIRES_OK(
-        ctx, stream->parent()->Launch(stream, config.threadDim, config.blockDim,
-                                      *kernel, args));
-  }
-
- protected:
-  absl::Span<const uint8_t> cubin_data_;
-
- private:
-  std::unique_ptr<se::KernelBase> kernel_;
-  std::mutex mu_;
-};
-
-class MlirGeneratedTanhF16Op : public MlirGeneratedTanhOp {
- public:
-  explicit MlirGeneratedTanhF16Op(OpKernelConstruction* ctx)
-      : MlirGeneratedTanhOp(ctx) {
-    cubin_data_ = kTanhF16Kernel;
-  }
-};
-
-class MlirGeneratedTanhF32Op : public MlirGeneratedTanhOp {
- public:
-  explicit MlirGeneratedTanhF32Op(OpKernelConstruction* ctx)
-      : MlirGeneratedTanhOp(ctx) {
-    cubin_data_ = kTanhF32Kernel;
-  }
-};
-
-class MlirGeneratedTanhF64Op : public MlirGeneratedTanhOp {
- public:
-  explicit MlirGeneratedTanhF64Op(OpKernelConstruction* ctx)
-      : MlirGeneratedTanhOp(ctx) {
-    cubin_data_ = kTanhF64Kernel;
-  }
-};
+GENERATE_OP_KERNEL_BASE(Tanh);
 }  // namespace

-REGISTER_KERNEL_BUILDER(
-    Name("Tanh").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    MlirGeneratedTanhF16Op);
-REGISTER_KERNEL_BUILDER(
-    Name("Tanh").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    MlirGeneratedTanhF32Op);
-REGISTER_KERNEL_BUILDER(
-    Name("Tanh").Device(DEVICE_GPU).TypeConstraint<double>("T"),
-    MlirGeneratedTanhF64Op);
+REGISTER_AND_GENERATE_KERNEL(Tanh, F16, Eigen::half)
+REGISTER_AND_GENERATE_KERNEL(Tanh, F32, float)
+REGISTER_AND_GENERATE_KERNEL(Tanh, F64, double)
 }  // namespace tensorflow
tensorflow/core/kernels/mlir_generated/gpu_abs_test.cc (new file, 95 lines)
@@ -0,0 +1,95 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <limits>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

class GpuAbsTest : public OpsTestBase {
 protected:
  void SetUp() override {
    std::unique_ptr<tensorflow::Device> device_gpu(
        tensorflow::DeviceFactory::NewDevice("GPU", {},
                                             "/job:a/replica:0/task:0"));
    SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
  }
  template <typename T, typename RT = T>
  void RunAbsOp(std::initializer_list<T> input) {
    TensorShape shape({2, 3});
    TF_ASSERT_OK(NodeDefBuilder("abs_op", "Abs")
                     .Input(FakeInput(DataTypeToEnum<T>::v()))
                     .Attr("T", DataTypeToEnum<T>::v())
                     .Finalize(node_def()));

    TF_ASSERT_OK(InitOp());
    AddInputFromArray<T>(shape, input);
    TF_ASSERT_OK(RunOpKernel());

    Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
    std::vector<T> expected;
    expected.reserve(input.size());
    for (const T& inp : input) {
      expected.push_back(static_cast<T>(std::abs(static_cast<RT>(inp))));
    }
    test::FillValues<T>(&expected_tensor, expected);
    test::ExpectEqual(expected_tensor, *GetOutput(0));
  }
};
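// The second template parameter RT is the type in which the expected value is
// computed: AbsHalf below passes RT = float so std::abs runs at float
// precision before the result is cast back to Eigen::half.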

TEST_F(GpuAbsTest, AbsFloat) {
  RunAbsOp<float>({-std::numeric_limits<float>::infinity(), -0.1f, -0.0f, 0.0f,
                   0.1f, std::numeric_limits<float>::infinity()});
}

TEST_F(GpuAbsTest, AbsDouble) {
  RunAbsOp<double>({-std::numeric_limits<double>::infinity(), -0.1, -0.0, 0.0,
                    0.1, std::numeric_limits<double>::infinity()});
}

TEST_F(GpuAbsTest, AbsHalf) {
  RunAbsOp<Eigen::half, float>(
      {static_cast<Eigen::half>(-std::numeric_limits<double>::infinity()),
       static_cast<Eigen::half>(-0.1), static_cast<Eigen::half>(-0.0),
       static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(0.1),
       static_cast<Eigen::half>(std::numeric_limits<double>::infinity())});
}

TEST_F(GpuAbsTest, AbsInt32) {
  RunAbsOp<int32>({std::numeric_limits<int32>::min(),
                   std::numeric_limits<int32>::min() + 1, -1, 0, 1,
                   std::numeric_limits<int32>::max()});
}

TEST_F(GpuAbsTest, AbsInt64) {
  RunAbsOp<int64>({std::numeric_limits<int64>::min(),
                   std::numeric_limits<int64>::min() + 1, -1, 0, 1,
                   std::numeric_limits<int64>::max()});
}

}  // namespace
}  // end namespace tensorflow
abs.mlir.tmpl
@@ -1,4 +1,4 @@
-func @Abs(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
+func @abs(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
   %0 = "tf.Abs"(%arg0) { }
     : (tensor<?xelem_type>) -> tensor<?xelem_type>
   return %0 : tensor<?xelem_type>
tanh.mlir.tmpl
@@ -1,4 +1,4 @@
-func @Tanh(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
+func @tanh(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
   %0 = "tf.Tanh"(%arg0) { }
     : (tensor<?xelem_type>) -> tensor<?xelem_type>
   return %0 : tensor<?xelem_type>
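The template entry points are renamed to lowercase because the kernel lookup changed with the new base class: the old Tanh op asked for "Tanh_kernel", while MlirGeneratedUnaryOp now uses absl::AsciiStrToLower(#kernel_name "_kernel"), i.e. "abs_kernel" and "tanh_kernel". The generated kernel's symbol is presumably derived from the function name, so the casing has to match.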