Add support for f16 generated tanh kernel.

PiperOrigin-RevId: 316487197
Change-Id: Id7daff4dc6264071c9371e9eb31c1f57ac044389
This commit is contained in:
Stephan Herhut 2020-06-15 09:59:35 -07:00 committed by TensorFlower Gardener
parent b09d410f3b
commit cf83ab15b0
4 changed files with 15 additions and 6 deletions

View File

@ -32,6 +32,7 @@ gen_kernel_library(
name = "tanh",
tile_size = "256",
types = [
"f16",
"f32",
"f64",
],

View File

@ -20,9 +20,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
#if MLIR_GENERATED_GPU_KERNELS_ENABLED
DEFINE_UNARY(tanh, Eigen::half);
#else
#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
DEFINE_UNARY3(tanh, Eigen::half, float, double);
#endif
DEFINE_SIMPLE_BINARY3(tanh_grad, Eigen::half, float, double);

View File

@ -21,9 +21,7 @@ REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
complex64, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if MLIR_GENERATED_GPU_KERNELS_ENABLED
REGISTER(UnaryOp, GPU, "Tanh", functor::tanh, Eigen::half);
#else
#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
REGISTER3(UnaryOp, GPU, "Tanh", functor::tanh, float, Eigen::half, double);
#endif
#endif

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cubin_headers/tanh_f16_kernel.h"
#include "tensorflow/core/kernels/cubin_headers/tanh_f32_kernel.h"
#include "tensorflow/core/kernels/cubin_headers/tanh_f64_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
@ -102,6 +103,14 @@ class MlirGenerateTanhOp : public OpKernel {
std::mutex mu_;
};
// Tanh kernel variant that selects the f16 cubin. It adds nothing to
// MlirGenerateTanhOp except pointing cubin_data_ at kTanhF16Kernel.
// NOTE(review): kTanhF16Kernel presumably comes from
// cubin_headers/tanh_f16_kernel.h -- confirm.
class MlirGenerateTanhF16Op : public MlirGenerateTanhOp {
 public:
  explicit MlirGenerateTanhF16Op(OpKernelConstruction* ctx)
      : MlirGenerateTanhOp(ctx) {
    // cubin_data_ is not declared in this class, so it cannot appear in the
    // member-initializer list; assign it once the base is constructed.
    this->cubin_data_ = kTanhF16Kernel;
  }
};
class MlirGenerateTanhF32Op : public MlirGenerateTanhOp {
public:
explicit MlirGenerateTanhF32Op(OpKernelConstruction* ctx)
@ -119,6 +128,9 @@ class MlirGenerateTanhF64Op : public MlirGenerateTanhOp {
};
} // namespace
// Registers MlirGenerateTanhF16Op as the kernel builder for the "Tanh" op on
// DEVICE_GPU, with the "T" type constraint set to Eigen::half.
REGISTER_KERNEL_BUILDER(
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
MlirGenerateTanhF16Op);
// Registers MlirGenerateTanhF32Op as the kernel builder for the "Tanh" op on
// DEVICE_GPU, with the "T" type constraint set to float.
REGISTER_KERNEL_BUILDER(
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<float>("T"),
MlirGenerateTanhF32Op);