From a4c706fe0990fbbf0dc8b572674e8ebfeb64869e Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 30 Jul 2020 07:05:53 -0700 Subject: [PATCH] Add MLIR generated abs kernel for the GPU backend Extract a common base class from the Tanh op that is used for both Tanh and Abs. The abs kernel is also behind the tensorflow_enable_mlir_generated_gpu_kernels flag. PiperOrigin-RevId: 323995229 Change-Id: I58b9e5c7dc0f428983c14bef23aa806ede6cf36c --- tensorflow/core/kernels/cwise_op_abs.cc | 4 + tensorflow/core/kernels/mlir_generated/BUILD | 42 ++++- .../mlir_generated/cwise_op_gpu_abs.cu.cc | 40 +++++ .../mlir_generated/cwise_op_gpu_base.cu.cc | 129 +++++++++++++++ .../mlir_generated/cwise_op_gpu_base.cu.h | 77 +++++++++ .../mlir_generated/cwise_op_gpu_tanh.cu.cc | 152 +----------------- .../kernels/mlir_generated/gpu_abs_test.cc | 95 +++++++++++ .../op_definitions/abs.mlir.tmpl | 2 +- .../op_definitions/tanh.mlir.tmpl | 2 +- 9 files changed, 394 insertions(+), 149 deletions(-) create mode 100644 tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cu.cc create mode 100644 tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.cc create mode 100644 tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h create mode 100644 tensorflow/core/kernels/mlir_generated/gpu_abs_test.cc diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc index e4f01cf6c90..d3b09f7078a 100644 --- a/tensorflow/core/kernels/cwise_op_abs.cc +++ b/tensorflow/core/kernels/cwise_op_abs.cc @@ -21,12 +21,15 @@ REGISTER8(UnaryOp, CPU, "Abs", functor::abs, Eigen::half, bfloat16, float, REGISTER2(UnaryOp, CPU, "ComplexAbs", functor::abs, complex64, complex128); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED REGISTER4(UnaryOp, GPU, "Abs", functor::abs, Eigen::half, float, double, int64); +#endif REGISTER2(UnaryOp, GPU, "ComplexAbs", functor::abs, complex64, complex128); // A special GPU kernel for int32. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. +#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED REGISTER_KERNEL_BUILDER(Name("Abs") .Device(DEVICE_GPU) .HostMemory("x") @@ -34,6 +37,7 @@ REGISTER_KERNEL_BUILDER(Name("Abs") .TypeConstraint("T"), UnaryOp>); #endif +#endif #if TENSORFLOW_USE_SYCL REGISTER3(UnaryOp, SYCL, "Abs", functor::abs, float, double, int64); diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index 7ee635fce32..9f3efe9d972 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -26,11 +26,18 @@ config_setting( tf_kernel_library( name = "cwise_op", - gpu_srcs = ["cwise_op_gpu_tanh.cu.cc"], + gpu_srcs = [ + "cwise_op_gpu_base.cu.cc", + "cwise_op_gpu_base.cu.h", + "cwise_op_gpu_abs.cu.cc", + "cwise_op_gpu_tanh.cu.cc", + ], tags = ["manual"], deps = if_cuda([ + ":abs_kernels", ":tanh_kernels", "@com_google_absl//absl/strings", + "//third_party/eigen3", "@com_google_absl//absl/types:span", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -57,6 +64,25 @@ tf_cuda_cc_test( ], ) +tf_cuda_cc_test( + name = "gpu_abs_test", + size = "small", + srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_abs_test.cc"]), + tags = tf_cuda_tests_tags() + ["no_rocm"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:device", + "//tensorflow/core/common_runtime:device_factory", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:ops_testutil", + ], +) + # TODO(b/160731748): Re-enable when it works again. # gen_kernel_library( # name = "bias_add", @@ -92,3 +118,17 @@ gen_kernel_library( ], unroll_factors = "4", ) + +gen_kernel_library( + name = "abs", + same_shape = "0,1", + tile_size = "256", + types = [ + "f16", + "f32", + "f64", + "i32", + "i64", + ], + unroll_factors = "4", +) diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cu.cc b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cu.cc new file mode 100644 index 00000000000..1920317a7ae --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cu.cc @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/mlir_generated/abs_f16_kernel.h" +#include "tensorflow/core/kernels/mlir_generated/abs_f32_kernel.h" +#include "tensorflow/core/kernels/mlir_generated/abs_f64_kernel.h" +#include "tensorflow/core/kernels/mlir_generated/abs_i32_kernel.h" +#include "tensorflow/core/kernels/mlir_generated/abs_i64_kernel.h" +#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h" + +namespace tensorflow { +namespace { +GENERATE_OP_KERNEL_BASE(Abs); +} // namespace + +REGISTER_AND_GENERATE_KERNEL(Abs, F16, Eigen::half); +REGISTER_AND_GENERATE_KERNEL(Abs, F32, float); +REGISTER_AND_GENERATE_KERNEL(Abs, F64, double); +REGISTER_AND_GENERATE_KERNEL(Abs, I32, int32); +REGISTER_AND_GENERATE_KERNEL(Abs, I64, int64); +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.cc b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.cc new file mode 100644 index 00000000000..5a5c9ed6a42 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.cc @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { +namespace { +Status CreateKernel(absl::string_view kernel_name, uint64_t num_args, + absl::string_view ptx, absl::Span cubin_data, + se::StreamExecutor* stream_exec, + std::unique_ptr& kernel_base) { + se::MultiKernelLoaderSpec loader_spec(num_args); + + if (!cubin_data.empty()) { + loader_spec.AddCudaCubinInMemory( + reinterpret_cast(cubin_data.data()), kernel_name); + } + + kernel_base.reset(new se::KernelBase(stream_exec)); + return stream_exec->GetKernel(loader_spec, kernel_base.get()); +} + +struct LaunchConfig { + se::BlockDim blockDim; + se::ThreadDim threadDim; +}; + +LaunchConfig GetLaunchConfiguration(std::vector tile_sizes, + std::vector unrolling_factors, + std::vector shape) { + LaunchConfig result; + // Ensure the vectors are length 3 and pad with ones. + tile_sizes.resize(3, 1); + unrolling_factors.resize(3, 1); + shape.resize(3, 1); + // The number of threads is given by the tiling size. + result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]); + // We know that the kernel was generated by mapping the three outer-most + // dimensions to x,y,z dimensions. So we only need to compute those. + std::vector block_dims(3); + for (int i = 0; i < 3; ++i) { + // Compute the number of grids. We use ceildiv here as we have to allocate + // an extra thread/block if the division is not even. The kernel contains + // code to handle the boundaries. + uint64 number_of_threads = Eigen::divup(shape[i], unrolling_factors[i]); + int number_of_grids = Eigen::divup(number_of_threads, tile_sizes[i]); + block_dims[i] = number_of_grids; + } + result.blockDim = se::BlockDim(block_dims[0], block_dims[1], block_dims[2]); + return result; +} +} // namespace + +void MlirGeneratedUnaryOp::Compute(OpKernelContext* ctx) { + auto* stream = ctx->op_device_context()->stream(); + se::KernelBase* kernel; + { + absl::MutexLock l(&mu_); + if (!kernel_) { + OP_REQUIRES_OK(ctx, CreateKernel(name_, 10, "", cubin_data_, + stream->parent(), kernel_)); + } + kernel = kernel_.get(); + } + + const Tensor& inp = ctx->input(0); + Tensor* out = nullptr; + OP_REQUIRES_OK( + ctx, ctx->forward_input_or_allocate_output({0}, 0, inp.shape(), &out)); + + if (inp.NumElements() == 0) { + return; + } + + se::KernelArgsArray<10> args; + + args.add_device_memory_argument( + stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes())); + args.add_device_memory_argument( + stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes())); + args.add_argument(0); + args.add_argument(inp.NumElements()); + args.add_argument(1); + + args.add_device_memory_argument( + stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes())); + args.add_device_memory_argument( + stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes())); + args.add_argument(0); + args.add_argument(inp.NumElements()); + args.add_argument(1); + + // This has to be aligned with the configuration that was used when building + // the kernels. See the corresponding build rules in the `BUILD` file. + LaunchConfig config = GetLaunchConfiguration( + {256}, {4}, {static_cast(inp.NumElements())}); + OP_REQUIRES_OK(ctx, stream->parent()->Launch(stream, config.threadDim, + config.blockDim, *kernel, args)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h new file mode 100644 index 00000000000..4e75aab6e16 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_CU_H_ +#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_CU_H_ + +#include +#include + +#include "absl/strings/ascii.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { +class MlirGeneratedUnaryOp : public OpKernel { + public: + MlirGeneratedUnaryOp(OpKernelConstruction* ctx, std::string name, + absl::Span cubin_data) + : OpKernel(ctx), name_(name), cubin_data_(cubin_data) {} + + void Compute(OpKernelContext* ctx) override; + + private: + std::string name_; + absl::Span cubin_data_; + std::unique_ptr kernel_; + absl::Mutex mu_; +}; + +#define GENERATE_OP_KERNEL_BASE(kernel_name) \ + class MlirGenerated##kernel_name##Op : public MlirGeneratedUnaryOp { \ + public: \ + MlirGenerated##kernel_name##Op(OpKernelConstruction* ctx, \ + absl::Span cubin_data) \ + : MlirGeneratedUnaryOp(ctx, \ + absl::AsciiStrToLower(#kernel_name "_kernel"), \ + cubin_data) {} \ + }; + +#define GENERATE_OP_KERNEL_FOR(kernel_name, data_type) \ + class MlirGenerated##kernel_name##data_type##Op \ + : public MlirGenerated##kernel_name##Op { \ + public: \ + explicit MlirGenerated##kernel_name##data_type##Op( \ + OpKernelConstruction* ctx) \ + : MlirGenerated##kernel_name \ + ##Op(ctx, k##kernel_name##data_type##Kernel) {} \ + }; + +#define REGISTER_AND_GENERATE_KERNEL(kernel_name, data_type, native_data_type) \ + namespace { \ + GENERATE_OP_KERNEL_FOR(kernel_name, data_type) \ + } \ + REGISTER_KERNEL_BUILDER(Name(#kernel_name) \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T"), \ + MlirGenerated##kernel_name##data_type##Op); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_CU_H_ diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cu.cc b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cu.cc index 9a2881be284..b113c4cad34 100644 --- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cu.cc +++ b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cu.cc @@ -13,164 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h" #include "tensorflow/core/kernels/mlir_generated/tanh_f16_kernel.h" #include "tensorflow/core/kernels/mlir_generated/tanh_f32_kernel.h" #include "tensorflow/core/kernels/mlir_generated/tanh_f64_kernel.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor.h" namespace tensorflow { namespace { -Status CreateKernel(absl::string_view kernel_name, uint64_t num_args, - absl::string_view ptx, absl::Span cubin_data, - se::StreamExecutor* stream_exec, - std::unique_ptr& kernel_base) { - se::MultiKernelLoaderSpec loader_spec(num_args); - - if (!cubin_data.empty()) { - loader_spec.AddCudaCubinInMemory( - reinterpret_cast(cubin_data.data()), kernel_name); - } - - kernel_base.reset(new se::KernelBase(stream_exec)); - return stream_exec->GetKernel(loader_spec, kernel_base.get()); -} - -struct LaunchConfig { - se::BlockDim blockDim; - se::ThreadDim threadDim; -}; - -LaunchConfig GetLaunchConfiguration(std::vector tile_sizes, - std::vector unrolling_factors, - std::vector shape) { - LaunchConfig result; - // Ensure the vectors are length 3 and pad with ones. - tile_sizes.resize(3, 1); - unrolling_factors.resize(3, 1); - shape.resize(3, 1); - // The number of threads is given by the tiling size. - result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]); - // We know that the kernel was generated by mapping the three outer-most - // dimensions to x,y,z dimensions. So we only need to compute those. - std::vector block_dims(3); - for (int i = 0; i < 3; ++i) { - // Compute the number of grids. We use ceildiv here as we have to allocate - // an extra thread/block if the division is not even. The kernel contains - // code to handle the boundaries. - int number_of_threads = - (shape[i] + unrolling_factors[i] - 1) / unrolling_factors[i]; - int number_of_grids = - (number_of_threads + tile_sizes[i] - 1) / tile_sizes[i]; - block_dims[i] = number_of_grids; - } - result.blockDim = se::BlockDim(block_dims[0], block_dims[1], block_dims[2]); - return result; -} - -class MlirGeneratedTanhOp : public OpKernel { - public: - explicit MlirGeneratedTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - auto* stream = ctx->op_device_context()->stream(); - se::KernelBase* kernel; - { - std::lock_guard l(mu_); - if (!kernel_) { - OP_REQUIRES_OK(ctx, CreateKernel("Tanh_kernel", 10, "", cubin_data_, - stream->parent(), kernel_)); - } - kernel = kernel_.get(); - } - - const Tensor& inp = ctx->input(0); - Tensor* out = nullptr; - OP_REQUIRES_OK( - ctx, ctx->forward_input_or_allocate_output({0}, 0, inp.shape(), &out)); - - if (inp.NumElements() == 0) { - return; - } - - se::KernelArgsArray<10> args; - - args.add_device_memory_argument( - stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes())); - args.add_device_memory_argument( - stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes())); - args.add_argument(0); - args.add_argument(inp.NumElements()); - args.add_argument(1); - - args.add_device_memory_argument( - stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes())); - args.add_device_memory_argument( - stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes())); - args.add_argument(0); - args.add_argument(inp.NumElements()); - args.add_argument(1); - - // This has to be aligned with the configuration that was used when - // generating the kernels. See the corresponding build rules in the `BUILD` - // file. - LaunchConfig config = GetLaunchConfiguration( - {256}, {4}, {static_cast(inp.NumElements())}); - OP_REQUIRES_OK( - ctx, stream->parent()->Launch(stream, config.threadDim, config.blockDim, - *kernel, args)); - } - - protected: - absl::Span cubin_data_; - - private: - std::unique_ptr kernel_; - std::mutex mu_; -}; - -class MlirGeneratedTanhF16Op : public MlirGeneratedTanhOp { - public: - explicit MlirGeneratedTanhF16Op(OpKernelConstruction* ctx) - : MlirGeneratedTanhOp(ctx) { - cubin_data_ = kTanhF16Kernel; - } -}; - -class MlirGeneratedTanhF32Op : public MlirGeneratedTanhOp { - public: - explicit MlirGeneratedTanhF32Op(OpKernelConstruction* ctx) - : MlirGeneratedTanhOp(ctx) { - cubin_data_ = kTanhF32Kernel; - } -}; - -class MlirGeneratedTanhF64Op : public MlirGeneratedTanhOp { - public: - explicit MlirGeneratedTanhF64Op(OpKernelConstruction* ctx) - : MlirGeneratedTanhOp(ctx) { - cubin_data_ = kTanhF64Kernel; - } -}; +GENERATE_OP_KERNEL_BASE(Tanh); } // namespace -REGISTER_KERNEL_BUILDER( - Name("Tanh").Device(DEVICE_GPU).TypeConstraint("T"), - MlirGeneratedTanhF16Op); -REGISTER_KERNEL_BUILDER( - Name("Tanh").Device(DEVICE_GPU).TypeConstraint("T"), - MlirGeneratedTanhF32Op); -REGISTER_KERNEL_BUILDER( - Name("Tanh").Device(DEVICE_GPU).TypeConstraint("T"), - MlirGeneratedTanhF64Op); +REGISTER_AND_GENERATE_KERNEL(Tanh, F16, Eigen::half) +REGISTER_AND_GENERATE_KERNEL(Tanh, F32, float) +REGISTER_AND_GENERATE_KERNEL(Tanh, F64, double) } // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/gpu_abs_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_abs_test.cc new file mode 100644 index 00000000000..ae76c023440 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/gpu_abs_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class GpuAbsTest : public OpsTestBase { + protected: + void SetUp() override { + std::unique_ptr device_gpu( + tensorflow::DeviceFactory::NewDevice("GPU", {}, + "/job:a/replica:0/task:0")); + SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu)); + } + template + void RunAbsOp(std::initializer_list input) { + TensorShape shape({2, 3}); + TF_ASSERT_OK(NodeDefBuilder("abs_op", "Abs") + .Input(FakeInput(DataTypeToEnum::v())) + .Attr("T", DataTypeToEnum::v()) + .Finalize(node_def())); + + TF_ASSERT_OK(InitOp()); + AddInputFromArray(shape, input); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected_tensor(allocator(), DataTypeToEnum::value, shape); + std::vector expected; + expected.reserve(input.size()); + for (const T& inp : input) { + expected.push_back(static_cast(std::abs(static_cast(inp)))); + } + test::FillValues(&expected_tensor, expected); + test::ExpectEqual(expected_tensor, *GetOutput(0)); + } +}; + +TEST_F(GpuAbsTest, AbsFloat) { + RunAbsOp({-std::numeric_limits::infinity(), -0.1f, -0.0f, 0.0f, + 0.1f, std::numeric_limits::infinity()}); +} + +TEST_F(GpuAbsTest, AbsDouble) { + RunAbsOp({-std::numeric_limits::infinity(), -0.1, -0.0, 0.0, + 0.1, std::numeric_limits::infinity()}); +} + +TEST_F(GpuAbsTest, AbsHalf) { + RunAbsOp( + {static_cast(-std::numeric_limits::infinity()), + static_cast(-0.1), static_cast(-0.0), + static_cast(0.0), static_cast(0.1), + static_cast(std::numeric_limits::infinity())}); +} + +TEST_F(GpuAbsTest, AbsInt32) { + RunAbsOp({std::numeric_limits::min(), + std::numeric_limits::min() + 1, -1, 0, 1, + std::numeric_limits::max()}); +} + +TEST_F(GpuAbsTest, AbsInt64) { + RunAbsOp({std::numeric_limits::min(), + std::numeric_limits::min() + 1, -1, 0, 1, + std::numeric_limits::max()}); +} + +} // namespace +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl index bca0c59cd77..d4c9bd5eaed 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl @@ -1,4 +1,4 @@ -func @Abs(%arg0: tensor) -> tensor { +func @abs(%arg0: tensor) -> tensor { %0 = "tf.Abs"(%arg0) { } : (tensor) -> tensor return %0 : tensor diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl index a6000604210..58a9b61ef68 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl @@ -1,4 +1,4 @@ -func @Tanh(%arg0: tensor) -> tensor { +func @tanh(%arg0: tensor) -> tensor { %0 = "tf.Tanh"(%arg0) { } : (tensor) -> tensor return %0 : tensor