[KERNEL_GEN] Make TF op name and MLIR func prefix coincide in *.mlir.tmpl.

This make the macro in `unranked_op_gpu_base.h` cleaner. This is also a
preparation to simplify it even more.

PiperOrigin-RevId: 336264876
Change-Id: Ieef72ceef373d2b48db4e1deabb753846cb692cf
This commit is contained in:
Alexander Belyaev 2020-10-09 03:44:20 -07:00 committed by TensorFlower Gardener
parent 56ca08345f
commit 8dc7f8a7b7
45 changed files with 89 additions and 91 deletions

View File

@ -43,14 +43,12 @@ class MlirGeneratedUnaryOp : public OpKernel {
absl::Mutex mu_;
};
#define GENERATE_OP_KERNEL_BASE(kernel_name) \
class MlirGenerated##kernel_name##Op : public MlirGeneratedUnaryOp { \
public: \
MlirGenerated##kernel_name##Op(OpKernelConstruction* ctx, \
absl::Span<const uint8_t> cubin_data) \
: MlirGeneratedUnaryOp(ctx, \
absl::AsciiStrToLower(#kernel_name "_kernel"), \
cubin_data) {} \
#define GENERATE_OP_KERNEL_BASE(kernel_name) \
class MlirGenerated##kernel_name##Op : public MlirGeneratedUnaryOp { \
public: \
MlirGenerated##kernel_name##Op(OpKernelConstruction* ctx, \
absl::Span<const uint8_t> cubin_data) \
: MlirGeneratedUnaryOp(ctx, #kernel_name "_kernel", cubin_data) {} \
};
#define GENERATE_OP_KERNEL_FOR(kernel_name, data_type) \

View File

@ -1,4 +1,4 @@
func @abs_elem_type(%arg0: tensor<*xelem_type>)
func @Abs_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Abs"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @acos_elem_type(%arg0: tensor<*xelem_type>)
func @Acos_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Acos"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @acosh_elem_type(%arg0: tensor<*xelem_type>)
func @Acosh_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Acosh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @angle(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> {
func @Angle(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> {
%0 = "tf.Angle"(%arg0) : (tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type>
return %0 : tensor<?xelem_type>
}

View File

@ -1,4 +1,4 @@
func @asin(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
func @Asin(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
%0 = "tf.Asin"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type>
return %0 : tensor<?xelem_type>
}

View File

@ -1,4 +1,4 @@
func @atan(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
func @Atan(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
%0 = "tf.Atan"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type>
return %0 : tensor<?xelem_type>
}

View File

@ -1,4 +1,4 @@
func @bias_add(%arg0: tensor<?x?xelem_type>, %arg1: tensor<?xelem_type>)
func @BiasAdd(%arg0: tensor<?x?xelem_type>, %arg1: tensor<?xelem_type>)
-> tensor<?x?xelem_type> {
%0 = "tf.BiasAdd"(%arg0, %arg1)
: (tensor<?x?xelem_type>, tensor<?xelem_type>) -> tensor<?x?xelem_type>

View File

@ -1,4 +1,4 @@
func @ceil_elem_type(%arg0: tensor<*xelem_type>)
func @Ceil_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Ceil"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @conj(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
func @Conj(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
%0 = "tf.Conj"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type>
return %0 : tensor<?xelem_type>
}

View File

@ -1,4 +1,4 @@
func @cos_elem_type(%arg0: tensor<*xelem_type>)
func @Cos_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Cos"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @cosh_elem_type(%arg0: tensor<*xelem_type>)
func @Cosh_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Cosh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @digamma_elem_type(%arg0: tensor<*xelem_type>)
func @Digamma_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Digamma"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @erf_elem_type(%arg0: tensor<*xelem_type>)
func @Erf_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Erf"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @erfc_elem_type(%arg0: tensor<*xelem_type>)
func @Erfc_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Erfc"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @exp_elem_type(%arg0: tensor<*xelem_type>)
func @Exp_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Exp"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @expm1_elem_type(%arg0: tensor<*xelem_type>)
func @Expm1_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Expm1"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @floor_elem_type(%arg0: tensor<*xelem_type>)
func @Floor_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Floor"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @imag(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> {
func @Imag(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> {
%0 = "tf.Imag"(%arg0) : (tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type>
return %0 : tensor<?xelem_type>
}

View File

@ -1,4 +1,4 @@
func @invert_elem_type(%arg0: tensor<*xelem_type>)
func @Invert_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Invert"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @isfinite_elem_type(%arg0: tensor<*xelem_type>)
func @Isfinite_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.IsFinite"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1>
return %0 : tensor<*xi1>

View File

@ -1,4 +1,4 @@
func @isinf_elem_type(%arg0: tensor<*xelem_type>)
func @Isinf_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.IsInf"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1>
return %0 : tensor<*xi1>

View File

@ -1,4 +1,4 @@
func @isnan_elem_type(%arg0: tensor<*xelem_type>)
func @Isnan_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.IsNan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1>
return %0 : tensor<*xi1>

View File

@ -1,4 +1,4 @@
func @lgamma_elem_type(%arg0: tensor<*xelem_type>)
func @Lgamma_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Lgamma"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @log_elem_type(%arg0: tensor<*xelem_type>)
func @Log_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Log"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @log1p_elem_type(%arg0: tensor<*xelem_type>)
func @Log1p_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Log1p"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @logicalnot_elem_type(%arg0: tensor<*xelem_type>)
func @Logicalnot_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.LogicalNot"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @neg_elem_type(%arg0: tensor<*xelem_type>)
func @Neg_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Neg"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @real(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> {
func @Real(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> {
%0 = "tf.Real"(%arg0) : (tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type>
return %0 : tensor<?xelem_type>
}

View File

@ -1,4 +1,4 @@
func @reciprocal_elem_type(%arg0: tensor<*xelem_type>)
func @Reciprocal_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Reciprocal"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @relu_elem_type(%arg0: tensor<*xelem_type>)
func @Relu_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Relu"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @rint_elem_type(%arg0: tensor<*xelem_type>)
func @Rint_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Rint"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @round_elem_type(%arg0: tensor<*xelem_type>)
func @Round_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Round"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @rsqrt_elem_type(%arg0: tensor<*xelem_type>)
func @Rsqrt_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Rsqrt"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @sigmoid_elem_type(%arg0: tensor<*xelem_type>)
func @Sigmoid_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Sigmoid"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @sign_elem_type(%arg0: tensor<*xelem_type>)
func @Sign_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Sign"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @sin(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
func @Sin(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
%0 = "tf.Sin"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type>
return %0 : tensor<?xelem_type>
}

View File

@ -1,4 +1,4 @@
func @sinh_elem_type(%arg0: tensor<*xelem_type>)
func @Sinh_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Sinh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @sqrt_elem_type(%arg0: tensor<*xelem_type>)
func @Sqrt_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Sqrt"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @square_elem_type(%arg0: tensor<*xelem_type>)
func @Square_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Square"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @tan_elem_type(%arg0: tensor<*xelem_type>)
func @Tan_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Tan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -1,4 +1,4 @@
func @tanh_elem_type(%arg0: tensor<*xelem_type>)
func @Tanh_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Tanh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>

View File

@ -18,10 +18,10 @@ limitations under the License.
namespace tensorflow {
REGISTER_AND_GENERATE_KERNEL(Abs, DT_HALF, abs_f16, Eigen::half);
REGISTER_AND_GENERATE_KERNEL(Abs, DT_FLOAT, abs_f32, float);
REGISTER_AND_GENERATE_KERNEL(Abs, DT_DOUBLE, abs_f64, double);
REGISTER_AND_GENERATE_KERNEL(Abs, DT_INT32, abs_i32, int32);
REGISTER_AND_GENERATE_KERNEL(Abs, DT_INT64, abs_i64, int64);
REGISTER_AND_GENERATE_KERNEL(Abs, f16, DT_HALF, Eigen::half);
REGISTER_AND_GENERATE_KERNEL(Abs, f32, DT_FLOAT, float);
REGISTER_AND_GENERATE_KERNEL(Abs, f64, DT_DOUBLE, double);
REGISTER_AND_GENERATE_KERNEL(Abs, i32, DT_INT32, int32);
REGISTER_AND_GENERATE_KERNEL(Abs, i64, DT_INT64, int64);
} // namespace tensorflow

View File

@ -72,45 +72,45 @@ Tensor ConvertDescriptorToTensor(
return tensor;
}
#define MLIR_FUNCTION(mlir_func) _mlir_ciface_##mlir_func
#define MLIR_FUNCTION(tf_op, mlir_type) _mlir_ciface_##tf_op##_##mlir_type
// Generates a class derived from OpKernel with Compute function that converts
// input tensors to unranked memref descriptors and calls mlir-generated
// unranked kernel. The outputs are converted back to tensors using
// MlirTensorBuffer to take ownership of pre-allocated memory.
#define REGISTER_AND_GENERATE_KERNEL(tf_op, tf_data_type, mlir_func, \
data_type) \
extern "C" ::UnrankedMemRefType<data_type> MLIR_FUNCTION(mlir_func)( \
tensorflow::OpKernelContext * ctx, \
::UnrankedMemRefType<data_type> * arg); \
\
namespace { \
class MlirUnranked##tf_op##mlir_func##Op : public OpKernel { \
public: \
MlirUnranked##tf_op##mlir_func##Op(OpKernelConstruction* ctx) \
: OpKernel(ctx) {} \
\
void Compute(OpKernelContext* ctx) override { \
const Tensor& input = ctx->input(0); \
\
auto input_desc = ConvertTensorToDescriptor<data_type>(input); \
auto result_desc = MLIR_FUNCTION(mlir_func)(ctx, &input_desc); \
free(input_desc.descriptor); \
\
tensorflow::AllocatorAttributes attrs; \
auto* allocator = ctx->get_allocator(attrs); \
\
Tensor result_tensor = ConvertDescriptorToTensor<data_type>( \
result_desc, tf_data_type, allocator); \
free(result_desc.descriptor); \
ctx->set_output(0, result_tensor); \
} \
}; \
} \
\
REGISTER_KERNEL_BUILDER( \
Name(#tf_op).Device(DEVICE_GPU).TypeConstraint<data_type>("T"), \
MlirUnranked##tf_op##mlir_func##Op);
#define REGISTER_AND_GENERATE_KERNEL(tf_op, mlir_type, tf_data_type, \
data_type) \
extern "C" ::UnrankedMemRefType<data_type> MLIR_FUNCTION(tf_op, mlir_type)( \
tensorflow::OpKernelContext * ctx, \
::UnrankedMemRefType<data_type> * arg); \
\
namespace { \
class MlirUnranked##tf_op##mlir_type##Op : public OpKernel { \
public: \
MlirUnranked##tf_op##mlir_type##Op(OpKernelConstruction* ctx) \
: OpKernel(ctx) {} \
\
void Compute(OpKernelContext* ctx) override { \
const Tensor& input = ctx->input(0); \
\
auto input_desc = ConvertTensorToDescriptor<data_type>(input); \
auto result_desc = MLIR_FUNCTION(tf_op, mlir_type)(ctx, &input_desc); \
free(input_desc.descriptor); \
\
tensorflow::AllocatorAttributes attrs; \
auto* allocator = ctx->get_allocator(attrs); \
\
Tensor result_tensor = ConvertDescriptorToTensor<data_type>( \
result_desc, tf_data_type, allocator); \
free(result_desc.descriptor); \
ctx->set_output(0, result_tensor); \
} \
}; \
} \
\
REGISTER_KERNEL_BUILDER( \
Name(#tf_op).Device(DEVICE_GPU).TypeConstraint<data_type>("T"), \
MlirUnranked##tf_op##mlir_type##Op);
} // namespace tensorflow

View File

@ -18,8 +18,8 @@ limitations under the License.
namespace tensorflow {
REGISTER_AND_GENERATE_KERNEL(Tanh, DT_HALF, tanh_f16, Eigen::half);
REGISTER_AND_GENERATE_KERNEL(Tanh, DT_FLOAT, tanh_f32, float);
REGISTER_AND_GENERATE_KERNEL(Tanh, DT_DOUBLE, tanh_f64, double);
REGISTER_AND_GENERATE_KERNEL(Tanh, f16, DT_HALF, Eigen::half);
REGISTER_AND_GENERATE_KERNEL(Tanh, f32, DT_FLOAT, float);
REGISTER_AND_GENERATE_KERNEL(Tanh, f64, DT_DOUBLE, double);
} // namespace tensorflow