[KERNEL_GEN] Add unranked kernel for Tanh.
PiperOrigin-RevId: 335921952 Change-Id: I0bb4812dce78b73a4a37976936b348dad0ea4942
This commit is contained in:
parent
762cc6192b
commit
fb973b1c7d
@ -41,6 +41,7 @@ filegroup(
|
||||
srcs = if_mlir_unranked_kernels_enabled(
|
||||
[
|
||||
"unranked_op_gpu_abs.cc",
|
||||
"unranked_op_gpu_tanh.cc",
|
||||
"unranked_op_gpu_base.h",
|
||||
"unranked_op_gpu_base.cc",
|
||||
],
|
||||
@ -58,6 +59,7 @@ cc_library(
|
||||
deps = if_mlir_unranked_kernels_enabled(
|
||||
[
|
||||
":abs_unranked_kernels",
|
||||
":tanh_unranked_kernels",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen:tf_cuda_runtime_wrappers",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen:tf_framework_c_interface",
|
||||
],
|
||||
@ -352,6 +354,7 @@ gen_kernel_library(
|
||||
|
||||
gen_kernel_library(
|
||||
name = "tanh",
|
||||
generate_unranked = True,
|
||||
same_shape = "0,1",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
|
@ -18,10 +18,10 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, f16, Eigen::half, DT_HALF);
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, f32, float, DT_FLOAT);
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, f64, double, DT_DOUBLE);
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, i32, int32, DT_INT32);
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, i64, int64, DT_INT64);
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, DT_HALF, abs_f16, Eigen::half);
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, DT_FLOAT, abs_f32, float);
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, DT_DOUBLE, abs_f64, double);
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, DT_INT32, abs_i32, int32);
|
||||
REGISTER_AND_GENERATE_KERNEL(Abs, DT_INT64, abs_i64, int64);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -72,45 +72,45 @@ Tensor ConvertDescriptorToTensor(
|
||||
return tensor;
|
||||
}
|
||||
|
||||
#define MLIR_FUNCTION(data_type) _mlir_ciface_abs_##data_type
|
||||
#define MLIR_FUNCTION(mlir_func) _mlir_ciface_##mlir_func
|
||||
|
||||
// Generates a class derived from OpKernel with Compute function that converts
|
||||
// input tensors to unranked memref descriptors and calls mlir-generated
|
||||
// unranked kernel. The outputs are converted back to tensors using
|
||||
// MlirTensorBuffer to take ownership of pre-allocated memory.
|
||||
#define REGISTER_AND_GENERATE_KERNEL(kernel_name, type_name, data_type, \
|
||||
tf_data_type) \
|
||||
extern "C" ::UnrankedMemRefType<data_type> MLIR_FUNCTION(type_name)( \
|
||||
tensorflow::OpKernelContext * ctx, \
|
||||
::UnrankedMemRefType<data_type> * arg); \
|
||||
\
|
||||
namespace { \
|
||||
class MlirUnranked##kernel_name##type_name##Op : public OpKernel { \
|
||||
public: \
|
||||
MlirUnranked##kernel_name##type_name##Op(OpKernelConstruction* ctx) \
|
||||
: OpKernel(ctx) {} \
|
||||
\
|
||||
void Compute(OpKernelContext* ctx) override { \
|
||||
const Tensor& input = ctx->input(0); \
|
||||
\
|
||||
auto input_desc = ConvertTensorToDescriptor<data_type>(input); \
|
||||
auto result_desc = MLIR_FUNCTION(type_name)(ctx, &input_desc); \
|
||||
free(input_desc.descriptor); \
|
||||
\
|
||||
tensorflow::AllocatorAttributes attrs; \
|
||||
auto* allocator = ctx->get_allocator(attrs); \
|
||||
\
|
||||
Tensor result_tensor = ConvertDescriptorToTensor<data_type>( \
|
||||
result_desc, tf_data_type, allocator); \
|
||||
free(result_desc.descriptor); \
|
||||
ctx->set_output(0, result_tensor); \
|
||||
} \
|
||||
}; \
|
||||
} \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name(#kernel_name).Device(DEVICE_GPU).TypeConstraint<data_type>("T"), \
|
||||
MlirUnranked##kernel_name##type_name##Op);
|
||||
#define REGISTER_AND_GENERATE_KERNEL(tf_op, tf_data_type, mlir_func, \
|
||||
data_type) \
|
||||
extern "C" ::UnrankedMemRefType<data_type> MLIR_FUNCTION(mlir_func)( \
|
||||
tensorflow::OpKernelContext * ctx, \
|
||||
::UnrankedMemRefType<data_type> * arg); \
|
||||
\
|
||||
namespace { \
|
||||
class MlirUnranked##tf_op##mlir_func##Op : public OpKernel { \
|
||||
public: \
|
||||
MlirUnranked##tf_op##mlir_func##Op(OpKernelConstruction* ctx) \
|
||||
: OpKernel(ctx) {} \
|
||||
\
|
||||
void Compute(OpKernelContext* ctx) override { \
|
||||
const Tensor& input = ctx->input(0); \
|
||||
\
|
||||
auto input_desc = ConvertTensorToDescriptor<data_type>(input); \
|
||||
auto result_desc = MLIR_FUNCTION(mlir_func)(ctx, &input_desc); \
|
||||
free(input_desc.descriptor); \
|
||||
\
|
||||
tensorflow::AllocatorAttributes attrs; \
|
||||
auto* allocator = ctx->get_allocator(attrs); \
|
||||
\
|
||||
Tensor result_tensor = ConvertDescriptorToTensor<data_type>( \
|
||||
result_desc, tf_data_type, allocator); \
|
||||
free(result_desc.descriptor); \
|
||||
ctx->set_output(0, result_tensor); \
|
||||
} \
|
||||
}; \
|
||||
} \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name(#tf_op).Device(DEVICE_GPU).TypeConstraint<data_type>("T"), \
|
||||
MlirUnranked##tf_op##mlir_func##Op);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -0,0 +1,25 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_AND_GENERATE_KERNEL(Tanh, DT_HALF, tanh_f16, Eigen::half);
|
||||
REGISTER_AND_GENERATE_KERNEL(Tanh, DT_FLOAT, tanh_f32, float);
|
||||
REGISTER_AND_GENERATE_KERNEL(Tanh, DT_DOUBLE, tanh_f64, double);
|
||||
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user