Add support for f16 generated tanh kernel.
PiperOrigin-RevId: 316487197 Change-Id: Id7daff4dc6264071c9371e9eb31c1f57ac044389
This commit is contained in:
parent
b09d410f3b
commit
cf83ab15b0
@ -32,6 +32,7 @@ gen_kernel_library(
|
||||
name = "tanh",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
|
@ -20,9 +20,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
#if MLIR_GENERATED_GPU_KERNELS_ENABLED
|
||||
DEFINE_UNARY(tanh, Eigen::half);
|
||||
#else
|
||||
#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
|
||||
DEFINE_UNARY3(tanh, Eigen::half, float, double);
|
||||
#endif
|
||||
DEFINE_SIMPLE_BINARY3(tanh_grad, Eigen::half, float, double);
|
||||
|
@ -21,9 +21,7 @@ REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
|
||||
complex64, complex128);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if MLIR_GENERATED_GPU_KERNELS_ENABLED
|
||||
REGISTER(UnaryOp, GPU, "Tanh", functor::tanh, Eigen::half);
|
||||
#else
|
||||
#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
|
||||
REGISTER3(UnaryOp, GPU, "Tanh", functor::tanh, float, Eigen::half, double);
|
||||
#endif
|
||||
#endif
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/cubin_headers/tanh_f16_kernel.h"
|
||||
#include "tensorflow/core/kernels/cubin_headers/tanh_f32_kernel.h"
|
||||
#include "tensorflow/core/kernels/cubin_headers/tanh_f64_kernel.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -102,6 +103,14 @@ class MlirGenerateTanhOp : public OpKernel {
|
||||
std::mutex mu_;
|
||||
};
|
||||
|
||||
class MlirGenerateTanhF16Op : public MlirGenerateTanhOp {
|
||||
public:
|
||||
explicit MlirGenerateTanhF16Op(OpKernelConstruction* ctx)
|
||||
: MlirGenerateTanhOp(ctx) {
|
||||
cubin_data_ = kTanhF16Kernel;
|
||||
}
|
||||
};
|
||||
|
||||
class MlirGenerateTanhF32Op : public MlirGenerateTanhOp {
|
||||
public:
|
||||
explicit MlirGenerateTanhF32Op(OpKernelConstruction* ctx)
|
||||
@ -119,6 +128,9 @@ class MlirGenerateTanhF64Op : public MlirGenerateTanhOp {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
|
||||
MlirGenerateTanhF16Op);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<float>("T"),
|
||||
MlirGenerateTanhF32Op);
|
||||
|
Loading…
Reference in New Issue
Block a user