Add MLIR generated abs kernel for the GPU backend

Extract a common base class from the Tanh op that is
used for both Tanh and Abs.
The abs kernel is also behind the
tensorflow_enable_mlir_generated_gpu_kernels flag.

PiperOrigin-RevId: 323995229
Change-Id: I58b9e5c7dc0f428983c14bef23aa806ede6cf36c
This commit is contained in:
Adrian Kuegel 2020-07-30 07:05:53 -07:00 committed by TensorFlower Gardener
parent d2109607f0
commit a4c706fe09
9 changed files with 394 additions and 149 deletions

View File

@ -21,12 +21,15 @@ REGISTER8(UnaryOp, CPU, "Abs", functor::abs, Eigen::half, bfloat16, float,
REGISTER2(UnaryOp, CPU, "ComplexAbs", functor::abs, complex64, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
REGISTER4(UnaryOp, GPU, "Abs", functor::abs, Eigen::half, float, double, int64);
#endif
REGISTER2(UnaryOp, GPU, "ComplexAbs", functor::abs, complex64, complex128);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
REGISTER_KERNEL_BUILDER(Name("Abs")
.Device(DEVICE_GPU)
.HostMemory("x")
@ -34,6 +37,7 @@ REGISTER_KERNEL_BUILDER(Name("Abs")
.TypeConstraint<int32>("T"),
UnaryOp<CPUDevice, functor::abs<int32>>);
#endif
#endif
#if TENSORFLOW_USE_SYCL
REGISTER3(UnaryOp, SYCL, "Abs", functor::abs, float, double, int64);

View File

@ -26,11 +26,18 @@ config_setting(
tf_kernel_library(
name = "cwise_op",
gpu_srcs = ["cwise_op_gpu_tanh.cu.cc"],
gpu_srcs = [
"cwise_op_gpu_base.cu.cc",
"cwise_op_gpu_base.cu.h",
"cwise_op_gpu_abs.cu.cc",
"cwise_op_gpu_tanh.cu.cc",
],
tags = ["manual"],
deps = if_cuda([
":abs_kernels",
":tanh_kernels",
"@com_google_absl//absl/strings",
"//third_party/eigen3",
"@com_google_absl//absl/types:span",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -57,6 +64,25 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "gpu_abs_test",
size = "small",
srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_abs_test.cc"]),
tags = tf_cuda_tests_tags() + ["no_rocm"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/common_runtime:device",
"//tensorflow/core/common_runtime:device_factory",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:ops_testutil",
],
)
# TODO(b/160731748): Re-enable when it works again.
# gen_kernel_library(
# name = "bias_add",
@ -92,3 +118,17 @@ gen_kernel_library(
],
unroll_factors = "4",
)
gen_kernel_library(
name = "abs",
same_shape = "0,1",
tile_size = "256",
types = [
"f16",
"f32",
"f64",
"i32",
"i64",
],
unroll_factors = "4",
)

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/mlir_generated/abs_f16_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/abs_f32_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/abs_f64_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/abs_i32_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/abs_i64_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h"
namespace tensorflow {
namespace {
GENERATE_OP_KERNEL_BASE(Abs);
} // namespace
REGISTER_AND_GENERATE_KERNEL(Abs, F16, Eigen::half);
REGISTER_AND_GENERATE_KERNEL(Abs, F32, float);
REGISTER_AND_GENERATE_KERNEL(Abs, F64, double);
REGISTER_AND_GENERATE_KERNEL(Abs, I32, int32);
REGISTER_AND_GENERATE_KERNEL(Abs, I64, int64);
} // namespace tensorflow

View File

@ -0,0 +1,129 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h"
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
namespace {
Status CreateKernel(absl::string_view kernel_name, uint64_t num_args,
absl::string_view ptx, absl::Span<const uint8_t> cubin_data,
se::StreamExecutor* stream_exec,
std::unique_ptr<se::KernelBase>& kernel_base) {
se::MultiKernelLoaderSpec loader_spec(num_args);
if (!cubin_data.empty()) {
loader_spec.AddCudaCubinInMemory(
reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
}
kernel_base.reset(new se::KernelBase(stream_exec));
return stream_exec->GetKernel(loader_spec, kernel_base.get());
}
struct LaunchConfig {
se::BlockDim blockDim;
se::ThreadDim threadDim;
};
LaunchConfig GetLaunchConfiguration(std::vector<uint64> tile_sizes,
std::vector<uint64> unrolling_factors,
std::vector<uint64> shape) {
LaunchConfig result;
// Ensure the vectors are length 3 and pad with ones.
tile_sizes.resize(3, 1);
unrolling_factors.resize(3, 1);
shape.resize(3, 1);
// The number of threads is given by the tiling size.
result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]);
// We know that the kernel was generated by mapping the three outer-most
// dimensions to x,y,z dimensions. So we only need to compute those.
std::vector<int> block_dims(3);
for (int i = 0; i < 3; ++i) {
// Compute the number of grids. We use ceildiv here as we have to allocate
// an extra thread/block if the division is not even. The kernel contains
// code to handle the boundaries.
uint64 number_of_threads = Eigen::divup(shape[i], unrolling_factors[i]);
int number_of_grids = Eigen::divup(number_of_threads, tile_sizes[i]);
block_dims[i] = number_of_grids;
}
result.blockDim = se::BlockDim(block_dims[0], block_dims[1], block_dims[2]);
return result;
}
} // namespace
void MlirGeneratedUnaryOp::Compute(OpKernelContext* ctx) {
auto* stream = ctx->op_device_context()->stream();
se::KernelBase* kernel;
{
absl::MutexLock l(&mu_);
if (!kernel_) {
OP_REQUIRES_OK(ctx, CreateKernel(name_, 10, "", cubin_data_,
stream->parent(), kernel_));
}
kernel = kernel_.get();
}
const Tensor& inp = ctx->input(0);
Tensor* out = nullptr;
OP_REQUIRES_OK(
ctx, ctx->forward_input_or_allocate_output({0}, 0, inp.shape(), &out));
if (inp.NumElements() == 0) {
return;
}
se::KernelArgsArray<10> args;
args.add_device_memory_argument(
stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
args.add_device_memory_argument(
stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
args.add_argument<int64_t>(0);
args.add_argument<int64_t>(inp.NumElements());
args.add_argument<int64_t>(1);
args.add_device_memory_argument(
stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
args.add_device_memory_argument(
stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
args.add_argument<int64_t>(0);
args.add_argument<int64_t>(inp.NumElements());
args.add_argument<int64_t>(1);
// This has to be aligned with the configuration that was used when building
// the kernels. See the corresponding build rules in the `BUILD` file.
LaunchConfig config = GetLaunchConfiguration(
{256}, {4}, {static_cast<uint64>(inp.NumElements())});
OP_REQUIRES_OK(ctx, stream->parent()->Launch(stream, config.threadDim,
config.blockDim, *kernel, args));
}
} // namespace tensorflow

View File

@ -0,0 +1,77 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_CU_H_
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_CU_H_
#include <memory>
#include <string>
#include "absl/strings/ascii.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
class MlirGeneratedUnaryOp : public OpKernel {
public:
MlirGeneratedUnaryOp(OpKernelConstruction* ctx, std::string name,
absl::Span<const uint8_t> cubin_data)
: OpKernel(ctx), name_(name), cubin_data_(cubin_data) {}
void Compute(OpKernelContext* ctx) override;
private:
std::string name_;
absl::Span<const uint8_t> cubin_data_;
std::unique_ptr<se::KernelBase> kernel_;
absl::Mutex mu_;
};
#define GENERATE_OP_KERNEL_BASE(kernel_name) \
class MlirGenerated##kernel_name##Op : public MlirGeneratedUnaryOp { \
public: \
MlirGenerated##kernel_name##Op(OpKernelConstruction* ctx, \
absl::Span<const uint8_t> cubin_data) \
: MlirGeneratedUnaryOp(ctx, \
absl::AsciiStrToLower(#kernel_name "_kernel"), \
cubin_data) {} \
};
#define GENERATE_OP_KERNEL_FOR(kernel_name, data_type) \
class MlirGenerated##kernel_name##data_type##Op \
: public MlirGenerated##kernel_name##Op { \
public: \
explicit MlirGenerated##kernel_name##data_type##Op( \
OpKernelConstruction* ctx) \
: MlirGenerated##kernel_name \
##Op(ctx, k##kernel_name##data_type##Kernel) {} \
};
#define REGISTER_AND_GENERATE_KERNEL(kernel_name, data_type, native_data_type) \
namespace { \
GENERATE_OP_KERNEL_FOR(kernel_name, data_type) \
} \
REGISTER_KERNEL_BUILDER(Name(#kernel_name) \
.Device(DEVICE_GPU) \
.TypeConstraint<native_data_type>("T"), \
MlirGenerated##kernel_name##data_type##Op);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_CU_H_

View File

@ -13,164 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cu.h"
#include "tensorflow/core/kernels/mlir_generated/tanh_f16_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/tanh_f32_kernel.h"
#include "tensorflow/core/kernels/mlir_generated/tanh_f64_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
namespace {
Status CreateKernel(absl::string_view kernel_name, uint64_t num_args,
absl::string_view ptx, absl::Span<const uint8_t> cubin_data,
se::StreamExecutor* stream_exec,
std::unique_ptr<se::KernelBase>& kernel_base) {
se::MultiKernelLoaderSpec loader_spec(num_args);
if (!cubin_data.empty()) {
loader_spec.AddCudaCubinInMemory(
reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
}
kernel_base.reset(new se::KernelBase(stream_exec));
return stream_exec->GetKernel(loader_spec, kernel_base.get());
}
struct LaunchConfig {
se::BlockDim blockDim;
se::ThreadDim threadDim;
};
LaunchConfig GetLaunchConfiguration(std::vector<uint64> tile_sizes,
std::vector<uint64> unrolling_factors,
std::vector<uint64> shape) {
LaunchConfig result;
// Ensure the vectors are length 3 and pad with ones.
tile_sizes.resize(3, 1);
unrolling_factors.resize(3, 1);
shape.resize(3, 1);
// The number of threads is given by the tiling size.
result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]);
// We know that the kernel was generated by mapping the three outer-most
// dimensions to x,y,z dimensions. So we only need to compute those.
std::vector<int> block_dims(3);
for (int i = 0; i < 3; ++i) {
// Compute the number of grids. We use ceildiv here as we have to allocate
// an extra thread/block if the division is not even. The kernel contains
// code to handle the boundaries.
int number_of_threads =
(shape[i] + unrolling_factors[i] - 1) / unrolling_factors[i];
int number_of_grids =
(number_of_threads + tile_sizes[i] - 1) / tile_sizes[i];
block_dims[i] = number_of_grids;
}
result.blockDim = se::BlockDim(block_dims[0], block_dims[1], block_dims[2]);
return result;
}
class MlirGeneratedTanhOp : public OpKernel {
public:
explicit MlirGeneratedTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
auto* stream = ctx->op_device_context()->stream();
se::KernelBase* kernel;
{
std::lock_guard<std::mutex> l(mu_);
if (!kernel_) {
OP_REQUIRES_OK(ctx, CreateKernel("Tanh_kernel", 10, "", cubin_data_,
stream->parent(), kernel_));
}
kernel = kernel_.get();
}
const Tensor& inp = ctx->input(0);
Tensor* out = nullptr;
OP_REQUIRES_OK(
ctx, ctx->forward_input_or_allocate_output({0}, 0, inp.shape(), &out));
if (inp.NumElements() == 0) {
return;
}
se::KernelArgsArray<10> args;
args.add_device_memory_argument(
stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
args.add_device_memory_argument(
stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes()));
args.add_argument<int64_t>(0);
args.add_argument<int64_t>(inp.NumElements());
args.add_argument<int64_t>(1);
args.add_device_memory_argument(
stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
args.add_device_memory_argument(
stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes()));
args.add_argument<int64_t>(0);
args.add_argument<int64_t>(inp.NumElements());
args.add_argument<int64_t>(1);
// This has to be aligned with the configuration that was used when
// generating the kernels. See the corresponding build rules in the `BUILD`
// file.
LaunchConfig config = GetLaunchConfiguration(
{256}, {4}, {static_cast<uint64>(inp.NumElements())});
OP_REQUIRES_OK(
ctx, stream->parent()->Launch(stream, config.threadDim, config.blockDim,
*kernel, args));
}
protected:
absl::Span<const uint8_t> cubin_data_;
private:
std::unique_ptr<se::KernelBase> kernel_;
std::mutex mu_;
};
class MlirGeneratedTanhF16Op : public MlirGeneratedTanhOp {
public:
explicit MlirGeneratedTanhF16Op(OpKernelConstruction* ctx)
: MlirGeneratedTanhOp(ctx) {
cubin_data_ = kTanhF16Kernel;
}
};
class MlirGeneratedTanhF32Op : public MlirGeneratedTanhOp {
public:
explicit MlirGeneratedTanhF32Op(OpKernelConstruction* ctx)
: MlirGeneratedTanhOp(ctx) {
cubin_data_ = kTanhF32Kernel;
}
};
class MlirGeneratedTanhF64Op : public MlirGeneratedTanhOp {
public:
explicit MlirGeneratedTanhF64Op(OpKernelConstruction* ctx)
: MlirGeneratedTanhOp(ctx) {
cubin_data_ = kTanhF64Kernel;
}
};
GENERATE_OP_KERNEL_BASE(Tanh);
} // namespace
REGISTER_KERNEL_BUILDER(
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
MlirGeneratedTanhF16Op);
REGISTER_KERNEL_BUILDER(
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<float>("T"),
MlirGeneratedTanhF32Op);
REGISTER_KERNEL_BUILDER(
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<double>("T"),
MlirGeneratedTanhF64Op);
REGISTER_AND_GENERATE_KERNEL(Tanh, F16, Eigen::half)
REGISTER_AND_GENERATE_KERNEL(Tanh, F32, float)
REGISTER_AND_GENERATE_KERNEL(Tanh, F64, double)
} // namespace tensorflow

View File

@ -0,0 +1,95 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <limits>
#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class GpuAbsTest : public OpsTestBase {
protected:
void SetUp() override {
std::unique_ptr<tensorflow::Device> device_gpu(
tensorflow::DeviceFactory::NewDevice("GPU", {},
"/job:a/replica:0/task:0"));
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
}
template <typename T, typename RT = T>
void RunAbsOp(std::initializer_list<T> input) {
TensorShape shape({2, 3});
TF_ASSERT_OK(NodeDefBuilder("abs_op", "Abs")
.Input(FakeInput(DataTypeToEnum<T>::v()))
.Attr("T", DataTypeToEnum<T>::v())
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<T>(shape, input);
TF_ASSERT_OK(RunOpKernel());
Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
std::vector<T> expected;
expected.reserve(input.size());
for (const T& inp : input) {
expected.push_back(static_cast<T>(std::abs(static_cast<RT>(inp))));
}
test::FillValues<T>(&expected_tensor, expected);
test::ExpectEqual(expected_tensor, *GetOutput(0));
}
};
TEST_F(GpuAbsTest, AbsFloat) {
RunAbsOp<float>({-std::numeric_limits<float>::infinity(), -0.1f, -0.0f, 0.0f,
0.1f, std::numeric_limits<float>::infinity()});
}
TEST_F(GpuAbsTest, AbsDouble) {
RunAbsOp<double>({-std::numeric_limits<double>::infinity(), -0.1, -0.0, 0.0,
0.1, std::numeric_limits<double>::infinity()});
}
TEST_F(GpuAbsTest, AbsHalf) {
RunAbsOp<Eigen::half, float>(
{static_cast<Eigen::half>(-std::numeric_limits<double>::infinity()),
static_cast<Eigen::half>(-0.1), static_cast<Eigen::half>(-0.0),
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(0.1),
static_cast<Eigen::half>(std::numeric_limits<double>::infinity())});
}
TEST_F(GpuAbsTest, AbsInt32) {
RunAbsOp<int32>({std::numeric_limits<int32>::min(),
std::numeric_limits<int32>::min() + 1, -1, 0, 1,
std::numeric_limits<int32>::max()});
}
TEST_F(GpuAbsTest, AbsInt64) {
RunAbsOp<int64>({std::numeric_limits<int64>::min(),
std::numeric_limits<int64>::min() + 1, -1, 0, 1,
std::numeric_limits<int64>::max()});
}
} // namespace
} // end namespace tensorflow

View File

@ -1,4 +1,4 @@
func @Abs(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
func @abs(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
%0 = "tf.Abs"(%arg0) { }
: (tensor<?xelem_type>) -> tensor<?xelem_type>
return %0 : tensor<?xelem_type>

View File

@ -1,4 +1,4 @@
func @Tanh(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
func @tanh(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
%0 = "tf.Tanh"(%arg0) { }
: (tensor<?xelem_type>) -> tensor<?xelem_type>
return %0 : tensor<?xelem_type>