diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 4cfb216b532..731e882ea25 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -86,29 +86,6 @@ cc_library( ], ) -tf_cc_binary( - name = "tf_to_gpu_binary", - srcs = [ - "crash_handler.h", - "tf_to_gpu_binary.cc", - ], - visibility = [ - "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary:__pkg__", - "//tensorflow/core/kernels/mlir_generated:__pkg__", - ], - deps = [ - ":kernel_creator", - "//tensorflow/compiler/mlir:init_mlir", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/core:lib", - "//tensorflow/core/platform", - "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Pass", - ], -) - tf_cc_binary( name = "tf_to_kernel", srcs = ["tf_to_kernel.cc"], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 5221a87dfbd..192ef4f9cce 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -90,27 +90,17 @@ struct RemoveUnusedTensorToMemrefOperations }; } // end anonymous namespace -Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, - llvm::ArrayRef tile_sizes, +Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, bool embed_memref_prints) { mlir::PassManager pm(module.getContext()); applyTensorflowAndCLOptions(pm); - if (gpu_binary_only) { - pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( - /*allow_partial_conversion=*/false, /*legalize_chlo=*/true)); - pm.addNestedPass( - mlir::kernel_gen::transforms::CreateMaterializeBroadcastsPass()); - pm.addNestedPass( - mlir::kernel_gen::transforms::CreateUnfuseBatchNormPass()); - } else { - pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( - /*allow_partial_conversion=*/false, /*legalize_chlo=*/false)); - pm.addNestedPass(mlir::createTransformUnrankedHloPass()); - pm.addNestedPass(mlir::mhlo::createChloLegalizeToHloPass()); - pm.addNestedPass(mlir::createCanonicalizerPass()); - } + pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/false, /*legalize_chlo=*/false)); + pm.addNestedPass(mlir::createTransformUnrankedHloPass()); + pm.addNestedPass(mlir::mhlo::createChloLegalizeToHloPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); // Transform HLO operations to LinAlg. pm.addNestedPass(::mlir::mhlo::createLegalizeHloToLinalgPass()); @@ -139,12 +129,10 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, pm.addPass(mlir::kernel_gen::transforms::CreateHloBufferizePass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); - if (!gpu_binary_only) { - // Find candidates for buffer reuse. This is only successful if buffer size - // equality can be determined based on `linalg.generic` operations. - pm.addNestedPass( - mlir::kernel_gen::transforms::CreateBufferReusePass()); - } + // Find candidates for buffer reuse. This is only successful if buffer size + // equality can be determined based on `linalg.generic` operations. + pm.addNestedPass( + mlir::kernel_gen::transforms::CreateBufferReusePass()); pm.addNestedPass( mlir::createLinalgTilingToParallelLoopsPass((tiling_for_unrolling))); // Transform the Linalg ops inside of the loop nest into parallel loops. @@ -188,15 +176,13 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, std::make_unique()); pm.addPass(mlir::createCanonicalizerPass()); pm.addNestedPass(mlir::createCSEPass()); - if (!gpu_binary_only) { - // Before inserting more allocs, map the ones we already have to the - // tf runtime. That ensures that all allocations for the actual computation - // end up on the device, whereas allocations for shape computation and host - // side things remain on the host. - // Longer term, this should be handled by proper device placement. - pm.addPass(mlir::kernel_gen::tf_framework:: - CreateEmbedTFFrameworkFunctionAndAllocPass()); - } + // Before inserting more allocs, map the ones we already have to the + // tf runtime. That ensures that all allocations for the actual computation + // end up on the device, whereas allocations for shape computation and host + // side things remain on the host. + // Longer term, this should be handled by proper device placement. + pm.addPass(mlir::kernel_gen::tf_framework:: + CreateEmbedTFFrameworkFunctionAndAllocPass()); pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass()); pm.addNestedPass(mlir::createPromoteBuffersToStackPass(64)); // TODO(herhut): Depends on https://bugs.llvm.org/show_bug.cgi?id=48385. @@ -223,11 +209,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, // Take launches to launches with kernels. pm.addPass(::mlir::createGpuKernelOutliningPass()); - if (gpu_binary_only) { - // Make kernel signature deterministic so that we can call it externally. - pm.addNestedPass<::mlir::FuncOp>( - xla::mlir_gpu::createRewriteKernelSignaturePass()); - } pm.addPass(::mlir::createLowerAffinePass()); // Constraints are removed as late as possible and before lowering to CFG. pm.addNestedPass<::mlir::FuncOp>(::mlir::createConvertShapeConstraintsPass()); @@ -295,7 +276,7 @@ Status LowerHostSideToFinalForm(mlir::ModuleOp module) { } // namespace StatusOr GenerateKernelForTfCode( - mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, + mlir::MLIRContext& context, llvm::StringRef tf_code, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, bool embed_memref_prints, @@ -304,8 +285,8 @@ StatusOr GenerateKernelForTfCode( mlir::RegisterAllTensorFlowDialects(registry); registry.insert(); mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); - TF_RETURN_IF_ERROR(LowerTFtoGPU(module.get(), gpu_binary_only, tile_sizes, - unroll_factors, embed_memref_prints)); + TF_RETURN_IF_ERROR(LowerTFtoGPU(module.get(), tile_sizes, unroll_factors, + embed_memref_prints)); #if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA) return InternalError( "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined." @@ -321,9 +302,7 @@ StatusOr GenerateKernelForTfCode( TF_RETURN_IF_ERROR(GenerateDeviceCode(module.get(), kGpuBinaryAttrName, architectures, generate_fatbin, print_ptx)); - if (!gpu_binary_only) { - TF_RETURN_IF_ERROR(LowerHostSideToFinalForm(module.get())); - } + TF_RETURN_IF_ERROR(LowerHostSideToFinalForm(module.get())); return module; } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index 33be8ae9ef2..ac8ce845713 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -33,11 +33,9 @@ limitations under the License. namespace tensorflow { namespace kernel_gen { -// Converts TF code to LLVM/NVVM. If `gpu_binary_only` is true, then the -// conversion stops after gpu_binary blob is generated. If `gpu_binary_only` is -// false, lowers the host side to LLVM Dialect. +// Converts TF code to LLVM/NVVM. Lowers the host side to LLVM Dialect. xla::StatusOr GenerateKernelForTfCode( - mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, + mlir::MLIRContext& context, llvm::StringRef tf_code, llvm::ArrayRef architectures = {"sm_75"}, llvm::ArrayRef tile_sizes = {16, 64}, llvm::ArrayRef unroll_factors = {}, diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/BUILD deleted file mode 100644 index 6aef5c05fe9..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") - -package(licenses = ["notice"]) - -glob_lit_tests( - data = [ - "//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_gpu_binary", - "@llvm-project//mlir:run_lit.sh", - ], - default_tags = [ - # We need access to the CUDA SDK. - "gpu", - "no_rocm", - ], - driver = "//tensorflow/compiler/mlir:run_lit.sh", - test_file_exts = ["mlir"], -) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir deleted file mode 100644 index 51773093564..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir +++ /dev/null @@ -1,6 +0,0 @@ -// RUN: tf_to_gpu_binary --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 -func @abs(%arg0: tensor) -> tensor attributes {tf_entry} { - %0 = "tf.Abs"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir deleted file mode 100644 index bb505809abe..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir +++ /dev/null @@ -1,6 +0,0 @@ -// RUN: tf_to_gpu_binary --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 -func @ceil(%arg0: tensor) -> tensor attributes {tf_entry} { - %0 = "tf.Ceil"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir deleted file mode 100644 index fa88fc76c90..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir +++ /dev/null @@ -1,5 +0,0 @@ -// RUN: tf_to_gpu_binary --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 -func @tanh(%arg0: tensor) -> tensor attributes {tf_entry} { - %0 = "tf.Tanh"(%arg0) : (tensor) -> tensor - return %0 : tensor -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc deleted file mode 100644 index 6f1de7dc1bc..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2020 The TensorFlow Runtime Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- tf_to_gpu_binary.cc --------------------------------------*- C++ -*-===// -// -// This file implements the entry point to compile a tf op to a gpu binary -// -//===----------------------------------------------------------------------===// -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "llvm/Support/CommandLine.h" -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "tensorflow/compiler/mlir/init_mlir.h" -#include "tensorflow/compiler/mlir/tools/kernel_gen/crash_handler.h" -#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/stream_executor/lib/statusor.h" - -namespace tensorflow { -namespace kernel_gen { -namespace { - -xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, - std::string architecture, llvm::ArrayRef tile_sizes, - llvm::ArrayRef unroll_factors) { - // Read TF code. - std::string tf_code; - TF_RETURN_IF_ERROR( - ReadFileToString(Env::Default(), input_file.str(), &tf_code)); - // Compile. - mlir::MLIRContext context; - TF_ASSIGN_OR_RETURN( - mlir::OwningModuleRef module, - GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true, - architecture, tile_sizes, unroll_factors, - /*embed_memref_prints=*/false, - /*generate_fatbin=*/false)); - // Extract gpu_binary. - TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module)); - - // Write gpu_binary blob. - TF_RETURN_IF_ERROR( - WriteStringToFile(Env::Default(), output_file.str(), gpu_binary)); - return xla::Status::OK(); -} - -} // namespace -} // namespace kernel_gen -} // namespace tensorflow - -int main(int argc, char** argv) { - tensorflow::kernel_gen::SetCrashReportMessage(); - llvm::cl::opt input_file("input", llvm::cl::desc("input file"), - llvm::cl::value_desc("filename"), - llvm::cl::init("foo.mlir")); - llvm::cl::opt output_file( - "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"), - llvm::cl::init("foo.bin")); - llvm::cl::opt architecture( - "arch", llvm::cl::desc("target architecture (e.g. sm_50)"), - llvm::cl::init("sm_50")); - llvm::cl::list tile_sizes( - "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, - llvm::cl::CommaSeparated); - llvm::cl::list unroll_factors( - "unroll_factors", - llvm::cl::desc("factors to unroll by, separated by commas"), - llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); - - tensorflow::InitMlir y(&argc, &argv); - mlir::registerPassManagerCLOptions(); - llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n"); - - auto status = tensorflow::kernel_gen::Run( - input_file, output_file, architecture, tile_sizes, unroll_factors); - if (!status.ok()) { - LOG(ERROR) << status; - return 1; - } - return 0; -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc index a62a4136b4e..e0ad2349e89 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc @@ -115,9 +115,9 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, mlir::MLIRContext context; TF_ASSIGN_OR_RETURN( mlir::OwningModuleRef module, - GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false, - architectures, tile_sizes, unroll_factors, - embed_memref_prints, /*generate_fatbin=*/true, + GenerateKernelForTfCode(context, tf_code, architectures, tile_sizes, + unroll_factors, embed_memref_prints, + /*generate_fatbin=*/true, /*print_ptx=*/print_ptx)); // Get binary. TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module)); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 4fa6f025cc3..a1648745e44 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -77,20 +77,16 @@ cc_library( "embed_memref_prints.cc", "embed_tf_framework_pass.cc", "gpu_kernel_to_blob_pass.cc", - "materialize_broadcasts_pass.cc", "parallel_loops_to_sequential.cc", "same_shape_propagation.cc", "shape_to_descriptors_pass.cc", "tensorflow_abi_knowledge_propagation.cc", "tf_kernel_to_llvm_pass.cc", - "unfuse_batch_norm_pass.cc", ], hdrs = ["passes.h"], copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ "@llvm-project//mlir:Affine", - "//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep - "//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc deleted file mode 100644 index e0c21f0b2e4..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/materialize_broadcasts_pass.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" - -namespace mlir { -namespace kernel_gen { -namespace transforms { -namespace { - -#define GEN_PASS_CLASSES -#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" - -struct MaterializeBroadcastsPass - : public MaterializeBroadcastsPassBase { - void runOnFunction() override { - mlir::ConversionTarget conversionTarget(getContext()); - mlir::OwningRewritePatternList conversionPatterns; - - // Consider the mhlo dialect legal for tests. - conversionTarget.addLegalDialect(); - // The conversion uses helpers from the Standard dialect. - conversionTarget.addLegalDialect(); - - mlir::mhlo::SetupMaterializeBroadcastsLegality(&getContext(), - &conversionTarget); - mlir::mhlo::PopulateMaterializeBroadcastsPatterns(&getContext(), - &conversionPatterns); - - if (failed(applyPartialConversion(getFunction(), conversionTarget, - std::move(conversionPatterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr CreateMaterializeBroadcastsPass() { - return std::make_unique(); -} - -} // namespace transforms -} // namespace kernel_gen -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index f5169a16fac..98d831479f8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -62,9 +62,6 @@ std::unique_ptr> CreateHloBufferizePass(); // buffers. std::unique_ptr> CreateFinalBufferizePass(); -// Pass to materialize broadcasts. -std::unique_ptr CreateMaterializeBroadcastsPass(); - // Pass to convert scf::ParallelOp to scf::ForOp. std::unique_ptr CreateParallelLoopsToSequential(); @@ -74,9 +71,6 @@ std::unique_ptr> CreateGpuKernelToBlobPass( ArrayRef architectures = {}, bool generate_fatbin = true, bool print_ptx = false); -// Pass to unfuse batch norm. -std::unique_ptr CreateUnfuseBatchNormPass(); - // Pass to propagate tensorflow runtime ABI knowledge across kernel boundaries. std::unique_ptr CreatePropagateTfAbiKnowledgeToKernels(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index 2ec9bb3d3a6..abc1cb6ab06 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -61,16 +61,6 @@ def FinalBufferizePass : Pass<"final-bufferize", "ModuleOp"> { let constructor = "transforms::CreateFinalBufferizePass()"; } -def MaterializeBroadcastsPass : FunctionPass<"materialize-broadcast"> { - let summary = "Pass to materialize broadcasts"; - let constructor = "transforms::CreateMaterializeBroadcastsPass()"; -} - -def UnfuseBatchNormPass : FunctionPass<"unfuse-batch-norm"> { - let summary = "Pass to unfuse batch norm"; - let constructor = "transforms::CreateUnfuseBatchNormPass()"; -} - def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> { let summary = "Pass to annotate GPU Module with its PTX"; let options = [ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc deleted file mode 100644 index 5c347f471b1..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/unfuse_batch_norm_pass.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" - -namespace mlir { -namespace kernel_gen { -namespace transforms { -namespace { - -#define GEN_PASS_CLASSES -#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" - -struct UnfuseBatchNormPass - : public UnfuseBatchNormPassBase { - void runOnFunction() override { - mlir::OwningRewritePatternList patterns; - mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); - mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - -} // namespace - -std::unique_ptr CreateUnfuseBatchNormPass() { - return std::make_unique(); -} - -} // namespace transforms -} // namespace kernel_gen -} // namespace mlir diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index caca0208675..52d29b7bc93 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -261,7 +261,6 @@ tf_cuda_cc_test( gen_kernel_library( name = "abs", - generate_unranked = True, tile_size = "256", types = [ "f16", @@ -275,7 +274,6 @@ gen_kernel_library( gen_kernel_library( name = "conj", - generate_unranked = True, tile_size = "256", types = [ "f32", @@ -286,7 +284,6 @@ gen_kernel_library( gen_kernel_library( name = "imag", - generate_unranked = True, tile_size = "256", types = [ "f32", @@ -296,7 +293,6 @@ gen_kernel_library( gen_kernel_library( name = "invert", - generate_unranked = True, tile_size = "256", types = [ "i8", @@ -309,8 +305,6 @@ gen_kernel_library( gen_kernel_library( name = "is_inf", - generate_ranked = False, - generate_unranked = True, tile_size = "256", types = [ "f16", @@ -322,14 +316,12 @@ gen_kernel_library( gen_kernel_library( name = "logical_not", - generate_unranked = True, tile_size = "256", types = ["i1"], ) gen_kernel_library( name = "real", - generate_unranked = True, tile_size = "256", types = [ "f32", @@ -339,7 +331,6 @@ gen_kernel_library( gen_kernel_library( name = "sign", - generate_unranked = True, tile_size = "256", types = [ # TODO(b/162577610): Add bf16, c64 and c128. @@ -354,8 +345,6 @@ gen_kernel_library( gen_kernel_library( name = "add_v2", - generate_ranked = False, - generate_unranked = True, tile_size = "256,1,1", types = [ "f16", @@ -371,8 +360,6 @@ gen_kernel_library( [ gen_kernel_library( name = name, - generate_ranked = False, - generate_unranked = True, tile_size = "256,1,1", types = [ "i8", @@ -401,8 +388,6 @@ gen_kernel_library( [ gen_kernel_library( name = name, - generate_ranked = False, - generate_unranked = True, tile_size = "256,1,1", types = [ "i1", @@ -419,8 +404,6 @@ gen_kernel_library( [ gen_kernel_library( name = name, - generate_ranked = False, - generate_unranked = True, tile_size = "256,1,1", types = [ "f16", @@ -444,8 +427,6 @@ gen_kernel_library( [ gen_kernel_library( name = name, - generate_ranked = False, - generate_unranked = True, tile_size = "256,1,1", types = [ "f16", @@ -470,8 +451,6 @@ gen_kernel_library( [ gen_kernel_library( name = name, - generate_ranked = False, - generate_unranked = True, tile_size = "256,1,1", types = [ "f16", @@ -494,7 +473,6 @@ gen_kernel_library( [ gen_kernel_library( name = name, - generate_unranked = True, tile_size = "256", types = [ "f16", @@ -516,7 +494,6 @@ gen_kernel_library( [ gen_kernel_library( name = name, - generate_unranked = True, tile_size = "256", types = [ "f16", @@ -541,7 +518,6 @@ gen_kernel_library( [ gen_kernel_library( name = name, - generate_unranked = True, tile_size = "256", types = [ "f16", diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl index 5db18a55642..2605a4c9670 100644 --- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl +++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl @@ -31,168 +31,6 @@ GpuBinaryInfo = provider( fields = ["gpu_bins"], ) -def _gen_kernel_gpu_bin_impl(ctx): - name = ctx.attr.name - tile_sizes = ctx.attr.tile_size.replace("x", ",") - cmd_args = [] - if ctx.attr.unroll_factors: - cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors) - - if ctx.attr.extra_args: - cmd_args.extend(ctx.attr.extra_args) - - gpu_bins = [] - for arch in ctx.attr.gpu_archs: - # TODO(b/170283783): 'compute_' should generate both SASS and PTX. - arch = arch.replace("compute_", "sm_") - filename = "%s.%s.bin" % (name, arch) - gpu_bin = ctx.actions.declare_file(filename) - ctx.actions.run( - inputs = [ctx.file.mlir_op, ctx.file._tfso], - outputs = [gpu_bin], - executable = ctx.executable._tool, - arguments = cmd_args + [ - "--tile_sizes=%s" % tile_sizes, - "--arch=%s" % arch, - "--input=%s" % ctx.file.mlir_op.path, - "--output=%s" % gpu_bin.path, - ], - mnemonic = "compile", - ) - gpu_bins.append(gpu_bin) - return [GpuBinaryInfo(gpu_bins = gpu_bins)] - -_gen_kernel_gpu_bin_rule = rule( - attrs = { - "mlir_op": attr.label(mandatory = True, allow_single_file = True), - "tile_size": attr.string(mandatory = True), - "unroll_factors": attr.string(), - "gpu_archs": attr.string_list(mandatory = True), - "extra_args": attr.string_list(), - "_tfso": attr.label( - default = Label("//tensorflow:libtensorflow_framework.so.2"), - cfg = "host", - allow_single_file = True, - ), - "_tool": attr.label( - executable = True, - default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_gpu_binary"), - cfg = "host", - ), - }, - output_to_genfiles = True, - implementation = _gen_kernel_gpu_bin_impl, -) - -def _gen_kernel_image_hdr_impl_cuda(ctx): - images = [] - for cubin in ctx.attr.input[GpuBinaryInfo].gpu_bins: - arch = cubin.path.split(".")[-2] - images.append("--image=profile=%s,file=%s" % (arch, cubin.path)) - - # Generate fatbin file from all cubins. - fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name) - ctx.actions.run( - outputs = [fatbin], - inputs = ctx.attr.input[GpuBinaryInfo].gpu_bins, - executable = _lookup_file(ctx.attr._gpu_root, "bin/fatbinary"), - arguments = [ - "--64", - "--cmdline=--compile-only", - "--link", - "--compress-all", - "--create=%s" % fatbin.path, - ] + images, - mnemonic = "fatbinary", - ) - - bin2c = _lookup_file(ctx.attr._gpu_root, "bin/bin2c") - ctx.actions.run_shell( - outputs = [ctx.outputs.out], - inputs = [fatbin], - tools = [bin2c], - command = "%s --static --const --type=char --name=%s %s 1> %s" % - (bin2c.path, ctx.attr.symbol, fatbin.path, ctx.outputs.out.path), - mnemonic = "bin2c", - ) - -def _gen_kernel_image_hdr_impl_rocm(ctx): - hsaco_files = [] - hsaco_targets = [] - - # Add a dummy host target triple...clang-offload-bundler requires 1 and only 1 host target triple - hsaco_files.append("/dev/null") - hsaco_targets.append("host-x86_64-unknown-linux") - - hsacos = ctx.attr.input[GpuBinaryInfo].gpu_bins - for hsaco in hsacos: - gfx_arch = hsaco.path.split(".")[-2] - hsaco_files.append(hsaco.path) - hsaco_targets.append("hip-amdgcn-amd-amdhsa-%s" % gfx_arch) - - # Generate fatbin file from all hsacos. - fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name) - ctx.actions.run( - outputs = [fatbin], - inputs = hsacos, - executable = _lookup_file(ctx.attr._gpu_root, "bin/clang-offload-bundler"), - arguments = [ - "--inputs=%s" % ",".join(hsaco_files), - "--targets=%s" % ",".join(hsaco_targets), - "--type=o", - "--outputs=%s" % fatbin.path, - ], - mnemonic = "fatbinary", - ) - - ctx.actions.run_shell( - outputs = [ctx.outputs.out], - inputs = [fatbin], - command = ( - ("hex=`hexdump -v -e \'/1 \"0x%%02x, \"\' %s` && " + - "len=`echo $hex | wc -c` && " + - "echo 'static const unsigned char %s['$len' + 1] = {' > %s && " + - "echo $hex | cat >> %s && " + - "echo '};' >> %s") % ( - fatbin.path, - ctx.attr.symbol, - ctx.outputs.out.path, - ctx.outputs.out.path, - ctx.outputs.out.path, - ) - ), - ) - -_gen_kernel_image_hdr_rule = rule( - implementation = _gen_kernel_image_hdr_impl_rocm if rocm_is_configured() else _gen_kernel_image_hdr_impl_cuda, - output_to_genfiles = True, - attrs = { - "input": attr.label(mandatory = True, providers = [GpuBinaryInfo]), - "out": attr.output(mandatory = True), - "symbol": attr.string(mandatory = True), - "_gpu_root": attr.label( - default = Label("@local_config_rocm//rocm:rocm_root") if rocm_is_configured() else Label("@local_config_cuda//cuda:cuda_root"), - ), - }, -) - -def _gen_kernel_image_hdr(name, mlir_op, gpu_archs, tile_size, unroll_factors = None, extra_args = []): - """Generates a C header with fatbin data from a Tensorflow op.""" - _gen_kernel_gpu_bin_rule( - name = name + "_cubin", - mlir_op = mlir_op, - tile_size = tile_size, - unroll_factors = unroll_factors, - gpu_archs = gpu_archs, - extra_args = extra_args, - ) - _gen_kernel_image_hdr_rule( - name = name, - input = ":" + name + "_cubin", - out = "%s.h" % name, - symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""), - ) - type_to_mlir = { "c64": "complex", "c128": "complex", @@ -204,18 +42,12 @@ def _gen_mlir_op_impl(ctx): if mlir_type in type_to_mlir: mlir_type = type_to_mlir[mlir_type] - # In order to generate a ranked kernel we change *xelem_type to ?xelem_type - # and remove element type from the entry function name. - convert_to_ranked = "" - if ctx.attr.unranked == False: - convert_to_ranked = "sed s/*x/?x/g | sed s/_elem_type//g |" cmd = ctx.actions.run_shell( inputs = [ctx.file.template], outputs = [ctx.outputs.out], command = ( - ("cat %s | %s sed 's/_elem_type/_%s/g' | sed 's/elem_type/%s/g' > %s") % ( + ("cat %s | sed 's/_elem_type/_%s/g' | sed 's/elem_type/%s/g' > %s") % ( ctx.file.template.path, - convert_to_ranked, ctx.attr.type, mlir_type, ctx.outputs.out.path, @@ -244,40 +76,6 @@ def _gen_mlir_op(name, type, unranked): unranked = unranked, ) -def gen_ranked_kernel_library(name, types, tile_size, tags = [], unroll_factors = None, extra_args = []): - """ Generate a library with kernels for a specific tensorflow op. - - Args: - name: The name of the tensorflow op. - types: The types ("f16", "f32", "f64") for which a kernel should be generated. - tile_size: The tiling specification, e.g. "16x16". - unroll_factors: The unrolling specification, e.g. "4,4" - tags: The tags which should be added to the library. - extra_args: Extra arguments to pass to the generator tool. - """ - - if cuda_gpu_architectures() or rocm_gpu_architectures(): - for type in types: - _gen_mlir_op( - name = name, - type = type, - unranked = False, - ) - _gen_kernel_image_hdr( - name = "{name}_{type}_kernel".format(name = name, type = type), - mlir_op = "{name}_{type}.mlir".format(name = name, type = type), - gpu_archs = rocm_gpu_architectures() if rocm_is_configured() else cuda_gpu_architectures(), - tile_size = tile_size, - unroll_factors = unroll_factors, - extra_args = extra_args, - ) - - native.cc_library( - name = name + "_kernels", - hdrs = if_gpu_is_configured([":{name}_{type}_kernel".format(name = name, type = type) for type in types]), - tags = tags, - ) - ################################################################################ # Unranked kernels build rules. ################################################################################ @@ -410,22 +208,12 @@ def gen_unranked_kernel_library(name, types, tile_size, tags = [], unroll_factor tags = tags, ) -def gen_kernel_library(name, types, tile_size, tags = [], unroll_factors = None, extra_args = [], generate_ranked = True, generate_unranked = False): - if (generate_ranked): - gen_ranked_kernel_library( - name = name, - types = types, - tile_size = tile_size, - tags = tags, - unroll_factors = unroll_factors, - extra_args = extra_args, - ) - if (generate_unranked): - gen_unranked_kernel_library( - name = name + "_unranked", - types = types, - tile_size = tile_size, - tags = tags, - unroll_factors = unroll_factors, - extra_args = extra_args, - ) +def gen_kernel_library(name, types, tile_size, tags = [], unroll_factors = None, extra_args = []): + gen_unranked_kernel_library( + name = name + "_unranked", + types = types, + tile_size = tile_size, + tags = tags, + unroll_factors = unroll_factors, + extra_args = extra_args, + ) diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cc b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cc deleted file mode 100644 index 948a7c00437..00000000000 --- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_abs.cc +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/mlir_generated/abs_f16_kernel.h" -#include "tensorflow/core/kernels/mlir_generated/abs_f32_kernel.h" -#include "tensorflow/core/kernels/mlir_generated/abs_f64_kernel.h" -#include "tensorflow/core/kernels/mlir_generated/abs_i32_kernel.h" -#include "tensorflow/core/kernels/mlir_generated/abs_i64_kernel.h" -#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h" - -namespace tensorflow { -namespace { -GENERATE_OP_KERNEL_BASE(Abs); -} // namespace - -GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, F16, Eigen::half); -GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, F32, float); -GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, F64, double); -GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, I32, int32); -GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, I64, int64); -} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cc b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cc deleted file mode 100644 index c5fbb155923..00000000000 --- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h" - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor.h" - -namespace tensorflow { -namespace { -Status CreateKernel(absl::string_view kernel_name, uint64_t num_args, - absl::string_view ptx, absl::Span cubin_data, - se::StreamExecutor* stream_exec, - std::unique_ptr& kernel_base) { - se::MultiKernelLoaderSpec loader_spec(num_args); - - if (!cubin_data.empty()) { - loader_spec.AddCudaCubinInMemory( - reinterpret_cast(cubin_data.data()), kernel_name); - } - - kernel_base.reset(new se::KernelBase(stream_exec)); - return stream_exec->GetKernel(loader_spec, kernel_base.get()); -} - -struct LaunchConfig { - se::BlockDim blockDim; - se::ThreadDim threadDim; -}; - -LaunchConfig GetLaunchConfiguration(std::vector tile_sizes, - std::vector unrolling_factors, - std::vector shape) { - LaunchConfig result; - // Ensure the vectors are length 3 and pad with ones. - tile_sizes.resize(3, 1); - unrolling_factors.resize(3, 1); - shape.resize(3, 1); - // The number of threads is given by the tiling size. - result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]); - // We know that the kernel was generated by mapping the three outer-most - // dimensions to x,y,z dimensions. So we only need to compute those. - std::vector block_dims(3); - for (int i = 0; i < 3; ++i) { - // Compute the number of grids. We use ceildiv here as we have to allocate - // an extra thread/block if the division is not even. The kernel contains - // code to handle the boundaries. - uint64 number_of_threads = Eigen::divup(shape[i], unrolling_factors[i]); - int number_of_grids = Eigen::divup(number_of_threads, tile_sizes[i]); - block_dims[i] = number_of_grids; - } - result.blockDim = se::BlockDim(block_dims[0], block_dims[1], block_dims[2]); - return result; -} -} // namespace - -void MlirGeneratedUnaryOp::Compute(OpKernelContext* ctx) { - auto* stream = ctx->op_device_context()->stream(); - se::KernelBase* kernel; - { - absl::MutexLock l(&mu_); - if (!kernel_) { - OP_REQUIRES_OK(ctx, CreateKernel(name_, 10, "", cubin_data_, - stream->parent(), kernel_)); - } - kernel = kernel_.get(); - } - - const Tensor& inp = ctx->input(0); - Tensor* out = nullptr; - OP_REQUIRES_OK( - ctx, ctx->forward_input_or_allocate_output({0}, 0, inp.shape(), &out)); - - if (inp.NumElements() == 0) { - return; - } - - se::KernelArgsArray<10> args; - - args.add_device_memory_argument( - stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes())); - args.add_device_memory_argument( - stream_executor::DeviceMemoryBase(inp.data(), inp.TotalBytes())); - args.add_argument(0); - args.add_argument(inp.NumElements()); - args.add_argument(1); - - args.add_device_memory_argument( - stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes())); - args.add_device_memory_argument( - stream_executor::DeviceMemoryBase(out->data(), out->TotalBytes())); - args.add_argument(0); - args.add_argument(inp.NumElements()); - args.add_argument(1); - - // This has to be aligned with the configuration that was used when building - // the kernels. See the corresponding build rules in the `BUILD` file. - LaunchConfig config = GetLaunchConfiguration( - {256}, {4}, {static_cast(inp.NumElements())}); - OP_REQUIRES_OK(ctx, stream->parent()->Launch(stream, config.threadDim, - config.blockDim, *kernel, args)); -} - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h deleted file mode 100644 index 466bbead3a5..00000000000 --- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_H_ -#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_H_ - -#include -#include - -#include "absl/strings/ascii.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/platform/stream_executor.h" - -namespace tensorflow { -class MlirGeneratedUnaryOp : public OpKernel { - public: - MlirGeneratedUnaryOp(OpKernelConstruction* ctx, std::string name, - absl::Span cubin_data) - : OpKernel(ctx), name_(name), cubin_data_(cubin_data) {} - - void Compute(OpKernelContext* ctx) override; - - private: - std::string name_; - absl::Span cubin_data_; - std::unique_ptr kernel_; - absl::Mutex mu_; -}; - -#define GENERATE_OP_KERNEL_BASE(kernel_name) \ - class MlirGenerated##kernel_name##Op : public MlirGeneratedUnaryOp { \ - public: \ - MlirGenerated##kernel_name##Op(OpKernelConstruction* ctx, \ - absl::Span cubin_data) \ - : MlirGeneratedUnaryOp(ctx, #kernel_name "_kernel", cubin_data) {} \ - }; - -#define GENERATE_OP_KERNEL_FOR(kernel_name, data_type) \ - class MlirGenerated##kernel_name##data_type##Op \ - : public MlirGenerated##kernel_name##Op { \ - public: \ - explicit MlirGenerated##kernel_name##data_type##Op( \ - OpKernelConstruction* ctx) \ - : MlirGenerated##kernel_name \ - ##Op(ctx, k##kernel_name##data_type##Kernel) {} \ - }; - -#define GENERATE_AND_REGISTER_UNARY_KERNEL(kernel_name, data_type, \ - native_data_type) \ - namespace { \ - GENERATE_OP_KERNEL_FOR(kernel_name, data_type) \ - } \ - REGISTER_KERNEL_BUILDER(Name(#kernel_name) \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - MlirGenerated##kernel_name##data_type##Op); - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_CWISE_OP_GPU_BASE_H_ diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cc b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cc deleted file mode 100644 index a9cc0666b0b..00000000000 --- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_tanh.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h" -#include "tensorflow/core/kernels/mlir_generated/tanh_f16_kernel.h" -#include "tensorflow/core/kernels/mlir_generated/tanh_f32_kernel.h" -#include "tensorflow/core/kernels/mlir_generated/tanh_f64_kernel.h" - -namespace tensorflow { -namespace { -GENERATE_OP_KERNEL_BASE(Tanh); -} // namespace - -GENERATE_AND_REGISTER_UNARY_KERNEL(Tanh, F16, Eigen::half) -GENERATE_AND_REGISTER_UNARY_KERNEL(Tanh, F32, float) -GENERATE_AND_REGISTER_UNARY_KERNEL(Tanh, F64, double) -} // namespace tensorflow