diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index f839acd32a2..405471ab1e4 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -112,8 +112,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes", - "//tensorflow/compiler/mlir/tosa:tf_tosa_passes", - "//tensorflow/compiler/mlir/tosa:tfl_tosa_passes", + "//tensorflow/compiler/mlir/tosa:tf_passes", + "//tensorflow/compiler/mlir/tosa:tfl_passes", ], ) diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index cc423b81b6f..9032a409cc0 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -8,7 +8,7 @@ load("//third_party/mlir:tblgen.bzl", "gentbl") # TODO: Tighten visibility once targets are at the right granularity. package( - default_visibility = [":friends"], + default_visibility = [":internal"], licenses = ["notice"], # Apache 2.0 ) @@ -33,17 +33,13 @@ package_group( filegroup( name = "tosa_ops_td_files", srcs = [ - "@llvm-project//mlir:TdFiles", + "@llvm-project//mlir:TosaDialectTdFiles", ], - # TODO: Switch to pruned list of TD files once build file changes land. - # srcs = [ - # "@llvm-project//mlir:TosaDialectTdFiles", - # ], compatible_with = get_compatible_with_cloud(), ) gentbl( - name = "tosa_pass_inc_gen", + name = "tosa_passes_inc_gen", compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( @@ -58,6 +54,40 @@ gentbl( ], ) +cc_library( + name = "passes_header", + hdrs = [ + "transforms/passes.h", + "transforms/passes.h.inc", + ], + compatible_with = get_compatible_with_cloud(), + deps = ["@llvm-project//mlir:Pass"], +) + +cc_library( + name = "legalize_common", + srcs = [ + "transforms/legalize_common.cc", + "transforms/legalize_utils.cc", + ], + hdrs = [ + "transforms/legalize_common.h", + "transforms/legalize_utils.h", + ], + compatible_with = get_compatible_with_cloud(), + deps = [ + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:conv_grad_shape_utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TosaDialect", + ], + alwayslink = 1, +) + gentbl( name = "tosa_legalize_tf_inc_gen", compatible_with = get_compatible_with_cloud(), @@ -76,6 +106,36 @@ gentbl( ], ) +cc_library( + name = "tf_passes", + srcs = [ + "tf_passes.cc", + "transforms/fuse_bias_tf.cc", + "transforms/legalize_tf.cc", + "transforms/tf_legalize_patterns.inc", + ], + hdrs = [ + "tf_passes.h", + "transforms/passes.h", + ], + compatible_with = get_compatible_with_cloud(), + visibility = [":friends"], + deps = [ + ":legalize_common", + ":passes_header", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + gentbl( name = "tosa_legalize_tfl_inc_gen", compatible_with = get_compatible_with_cloud(), @@ -95,233 +155,31 @@ gentbl( ) cc_library( - name = "tosa_legalize_tf", - srcs = [ - "transforms/legalize_tf.cc", - "transforms/tf_legalize_patterns.inc", - ], - hdrs = [ - "transforms/legalize_common.h", - "transforms/legalize_utils.h", - "transforms/passes.h", - "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", - ], - compatible_with = get_compatible_with_cloud(), - deps = [ - ":tosa_legalize_tf_inc_gen", - ":tosa_pass_inc_gen", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_all_ops_inc_gen", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core/kernels:conv_grad_shape_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", - "@flatbuffers", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TosaDialect", - "@llvm-project//mlir:TransformUtils", - ], - alwayslink = 1, -) - -cc_library( - name = "tosa_legalize_tfl", + name = "tfl_passes", srcs = [ + "tfl_passes.cc", + "transforms/convert_tfl_uint8.cc", "transforms/legalize_tfl.cc", "transforms/tfl_legalize_patterns.inc", ], hdrs = [ - "transforms/legalize_common.h", - "transforms/legalize_utils.h", + "tfl_passes.h", "transforms/passes.h", - "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h", - "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], compatible_with = get_compatible_with_cloud(), + visibility = [":friends"], deps = [ - ":tosa_legalize_tfl_inc_gen", - ":tosa_pass_inc_gen", + ":legalize_common", + ":passes_header", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_inc_gen", - "//tensorflow/compiler/mlir/lite:validators", - "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", - "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core/kernels:conv_grad_shape_utils", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", - "@flatbuffers", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TosaDialect", - "@llvm-project//mlir:TransformUtils", - ], - alwayslink = 1, -) - -cc_library( - name = "tosa_legalize_common", - srcs = [ - "transforms/legalize_common.cc", - "transforms/legalize_utils.cc", - "transforms/tf_legalize_patterns.inc", - ], - hdrs = [ - "transforms/legalize_common.h", - "transforms/legalize_utils.h", - "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", - ], - compatible_with = get_compatible_with_cloud(), - deps = [ - "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_inc_gen", - "//tensorflow/compiler/mlir/lite:validators", - "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_all_ops_inc_gen", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core/kernels:conv_grad_shape_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", - "@flatbuffers", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TosaDialect", - "@llvm-project//mlir:TransformUtils", - ], - alwayslink = 1, -) - -cc_library( - name = "tosa_fuse_bias_tf", - srcs = [ - "transforms/fuse_bias_tf.cc", - ], - hdrs = [ - "transforms/passes.h", - "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", - ], - compatible_with = get_compatible_with_cloud(), - deps = [ - ":tosa_legalize_common", - ":tosa_pass_inc_gen", - "//tensorflow/compiler/mlir/tensorflow", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TosaDialect", - "@llvm-project//mlir:TransformUtils", - ], - alwayslink = 1, -) - -cc_library( - name = "tosa_convert_tfl_uint8", - srcs = [ - "transforms/convert_tfl_uint8.cc", - ], - hdrs = [ - "transforms/passes.h", - "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", - ], - compatible_with = get_compatible_with_cloud(), - deps = [ - ":tosa_legalize_common", - ":tosa_pass_inc_gen", - "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TosaDialect", - "@llvm-project//mlir:TransformUtils", - ], - alwayslink = 1, -) - -cc_library( - name = "tosa_pipelines", - srcs = [ - "tosa_passpipes.cc", - ], - hdrs = [ - "tosa_passpipes.h", - "transforms/passes.h", - "transforms/register_passes.h", - ], - compatible_with = get_compatible_with_cloud(), - deps = [ - ":tosa_pass_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:TosaDialect", - "@llvm-project//mlir:TransformUtils", - ], - alwayslink = 1, -) - -cc_library( - name = "tf_tosa_passes", - srcs = [ - "tf_tosa_pipeline.cc", - ], - hdrs = [ - ], - compatible_with = get_compatible_with_cloud(), - deps = [ - ":tosa_fuse_bias_tf", - ":tosa_legalize_common", - ":tosa_legalize_tf", - ":tosa_pipelines", - ], - alwayslink = 1, -) - -cc_library( - name = "tfl_tosa_passes", - srcs = [ - "tfl_tosa_pipeline.cc", - ], - hdrs = [ - ], - compatible_with = get_compatible_with_cloud(), - deps = [ - ":tosa_convert_tfl_uint8", - ":tosa_legalize_common", - ":tosa_legalize_tfl", - ":tosa_pipelines", + "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/tosa/tf_passes.cc b/tensorflow/compiler/mlir/tosa/tf_passes.cc new file mode 100644 index 00000000000..fadf7e54580 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/tf_passes.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tosa/tf_passes.h" + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" + +namespace mlir { +namespace tosa { + +void createTFtoTOSALegalizationPipeline( + OpPassManager& pm, const TOSATFLegalizationPipelineOptions& opts) { + //---------------------------------------------------------------------------- + // Prepare TFL module for conversion + //---------------------------------------------------------------------------- + // Inline all functions into main and then delete the functions themselves. + pm.addPass(mlir::createInlinerPass()); + + // Now that there is only one function, run some MLIR passes on it. + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + + pm.addPass(mlir::createLoopFusionPass()); + pm.addPass(mlir::createMemRefDataFlowOptPass()); + + //---------------------------------------------------------------------------- + // Perform main conversion. + // Now that there is only one function, run some MLIR passes on it. + //---------------------------------------------------------------------------- + pm.addPass(mlir::tosa::createFuseBiasTFPass()); + pm.addPass(mlir::tosa::createLegalizeTFPass()); + + //---------------------------------------------------------------------------- + // Post conversion cleanup. + //---------------------------------------------------------------------------- + pm.addPass(mlir::tosa::createTosaMakeBroadcastablePass()); + // Inline the call/return basic blocks within TOSA control flow ops. + pm.addPass(mlir::createInlinerPass()); + // Clean up with DCE. + pm.addPass(mlir::createSymbolDCEPass()); +} + +static mlir::PassPipelineRegistration + tf_tosa_pipeline("tf-to-tosa-pipeline", + "TensorFlow to TOSA legalization pipeline", + createTFtoTOSALegalizationPipeline); + +} // namespace tosa +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/tosa_passpipes.h b/tensorflow/compiler/mlir/tosa/tf_passes.h similarity index 65% rename from tensorflow/compiler/mlir/tosa/tosa_passpipes.h rename to tensorflow/compiler/mlir/tosa/tf_passes.h index eee7e634a12..18d11cde4d3 100644 --- a/tensorflow/compiler/mlir/tosa/tosa_passpipes.h +++ b/tensorflow/compiler/mlir/tosa/tf_passes.h @@ -16,28 +16,20 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H #define TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H -#include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/PassManager.h" -#include "llvm/ADT/Optional.h" -#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project namespace mlir { - namespace tosa { -void addPreOptMlirPasses(mlir::OpPassManager& pm); - -void addPostOptMlirPasses(mlir::OpPassManager& pm); +struct TOSATFLegalizationPipelineOptions + : public PassPipelineOptions {}; +// Legalizes TF dialect(s) to Tosa. void createTFtoTOSALegalizationPipeline( - OpPassManager& pm, const TOSALegalizationPipelineOptions& opts); - -void createTFLtoTOSALegalizationPipeline( - OpPassManager& pm, const TOSALegalizationPipelineOptions& opts); + OpPassManager& pm, const TOSATFLegalizationPipelineOptions& opts); } // namespace tosa - } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H diff --git a/tensorflow/compiler/mlir/tosa/tosa_passpipes.cc b/tensorflow/compiler/mlir/tosa/tfl_passes.cc similarity index 56% rename from tensorflow/compiler/mlir/tosa/tosa_passpipes.cc rename to tensorflow/compiler/mlir/tosa/tfl_passes.cc index 1bad41522f3..25d9041a508 100644 --- a/tensorflow/compiler/mlir/tosa/tosa_passpipes.cc +++ b/tensorflow/compiler/mlir/tosa/tfl_passes.cc @@ -13,23 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h" +#include "tensorflow/compiler/mlir/tosa/tfl_passes.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/Passes.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" namespace mlir { - namespace tosa { -void addPreOptMlirPasses(mlir::OpPassManager& pm) { +void createTFLtoTOSALegalizationPipeline( + OpPassManager& pm, const TOSATFLLegalizationPipelineOptions& opts) { + //---------------------------------------------------------------------------- + // Prepare TFL module for conversion + //---------------------------------------------------------------------------- // Inline all functions into main and then delete the functions themselves. pm.addPass(mlir::createInlinerPass()); @@ -39,9 +36,16 @@ void addPreOptMlirPasses(mlir::OpPassManager& pm) { pm.addPass(mlir::createLoopFusionPass()); pm.addPass(mlir::createMemRefDataFlowOptPass()); -} -void addPostOptMlirPasses(mlir::OpPassManager& pm) { + //---------------------------------------------------------------------------- + // Perform main conversion. + //---------------------------------------------------------------------------- + pm.addPass(mlir::tosa::createConvertTFLUint8Pass()); + pm.addPass(mlir::tosa::createLegalizeTFLPass()); + + //---------------------------------------------------------------------------- + // Post conversion cleanup. + //---------------------------------------------------------------------------- pm.addPass(mlir::tosa::createTosaMakeBroadcastablePass()); // Inline the call/return basic blocks within TOSA control flow ops. pm.addPass(mlir::createInlinerPass()); @@ -49,26 +53,10 @@ void addPostOptMlirPasses(mlir::OpPassManager& pm) { pm.addPass(mlir::createSymbolDCEPass()); } -void createTFtoTOSALegalizationPipeline( - OpPassManager& pm, const TOSALegalizationPipelineOptions& opts) { - addPreOptMlirPasses(pm); - - pm.addPass(mlir::tosa::createFuseBiasTFPass()); - pm.addPass(mlir::tosa::createLegalizeTFPass()); - - addPostOptMlirPasses(pm); -} - -void createTFLtoTOSALegalizationPipeline( - OpPassManager& pm, const TOSALegalizationPipelineOptions& opts) { - addPreOptMlirPasses(pm); - - pm.addPass(mlir::tosa::createConvertTFLUint8Pass()); - pm.addPass(mlir::tosa::createLegalizeTFLPass()); - - addPostOptMlirPasses(pm); -} +static mlir::PassPipelineRegistration + tfl_tosa_pipeline("tfl-to-tosa-pipeline", + "TensorFlow Lite to TOSA legalization pipeline", + createTFLtoTOSALegalizationPipeline); } // namespace tosa - } // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/tf_tosa_pipeline.cc b/tensorflow/compiler/mlir/tosa/tfl_passes.h similarity index 57% rename from tensorflow/compiler/mlir/tosa/tf_tosa_pipeline.cc rename to tensorflow/compiler/mlir/tosa/tfl_passes.h index e8d1aa73478..255418ae443 100644 --- a/tensorflow/compiler/mlir/tosa/tf_tosa_pipeline.cc +++ b/tensorflow/compiler/mlir/tosa/tfl_passes.h @@ -13,17 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h" +#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project namespace mlir { - namespace tosa { -static mlir::PassPipelineRegistration - tf_tosa_pipeline("tf-to-tosa-pipeline", - "TensorFlow to TOSA legalization pipeline", - createTFtoTOSALegalizationPipeline); +struct TOSATFLLegalizationPipelineOptions + : public PassPipelineOptions {}; + +// Legalizes TFL (TensorFlow lite) dialect(s) to Tosa. +void createTFLtoTOSALegalizationPipeline( + OpPassManager& pm, const TOSATFLLegalizationPipelineOptions& opts); } // namespace tosa - } // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tosa/tfl_tosa_pipeline.cc b/tensorflow/compiler/mlir/tosa/tfl_tosa_pipeline.cc deleted file mode 100644 index 8552a68101a..00000000000 --- a/tensorflow/compiler/mlir/tosa/tfl_tosa_pipeline.cc +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h" - -namespace mlir { - -namespace tosa { - -static mlir::PassPipelineRegistration - tfl_tosa_pipeline("tfl-to-tosa-pipeline", - "TensorFlow Lite to TOSA legalization pipeline", - createTFLtoTOSALegalizationPipeline); - -} // namespace tosa - -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc index 8a0e36dd941..08ee3c29ed4 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc @@ -29,25 +29,14 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" diff --git a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc index 74382ded178..058ba48e2c7 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc @@ -21,23 +21,10 @@ limitations under the License. #include #include -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h" diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 084efbf077a..9f987cad3b2 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -14,18 +14,26 @@ limitations under the License. ==============================================================================*/ // This file contains legalizations common to mapping both TensorFlow and -// TensorFlow Lite to TOSA. +// TensorFlow Lite to TOSA. It operates generically on ops and does not have +// a hard reference on either dialect. // // Conversion functions return llvm::None on a legalization failure or a // legalized value on success. Callers must check for presence of an // llvm::Optional value after each call. #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h" + #include #include #include #include #include + +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index 06016bbfb3b..d5ef518f176 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -16,39 +16,17 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H #define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + // This file contains legalizations common to mapping both TensorFlow and // TensorFlow Lite to TOSA. // -// Conversion functions return nullptr on a lowerization failure or a lowered -// operator on success. Callers must check and return a LogicalResult failure -// on nullptr. +// Conversion functions return None on a failure or result value on success. +// Callers must check and return a LogicalResult failure on nullptr. // // For these functions, the framework-specific operands/attributes/defaults // are already extracted and placed in a common form for lowering. -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/Support/FormatVariadic.h" namespace mlir { namespace tosa { diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index e24253b420c..1219e14eed5 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -21,30 +21,9 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringSwitch.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index 2ae339dc6d4..4e51bd795b7 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -23,32 +23,9 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/QuantTypes.h" -#include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSwitch.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" @@ -2996,5 +2973,6 @@ std::unique_ptr> createLegalizeTFLPass() { static PassRegistration pass( PASS_NAME, "Legalize from TensorFlow Lite to TOSA dialect"); + } // namespace tosa } // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index 5bae8eccf35..7280d4c23de 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -15,13 +15,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h" // Implements legalization and post-legalization optimization helper functions namespace mlir { - namespace tosa { // Create a TOSA rescale op from TFLite scaling, zero points and rounding mode diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index 69671a6a7a5..f18e5733b8b 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -22,24 +22,10 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" #include "tensorflow/core/util/padding.h" diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h index f9449080ec0..69d4e923d20 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/passes.h +++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h @@ -18,15 +18,11 @@ limitations under the License. #include -#include "mlir/Pass/Pass.h" +#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { - namespace tosa { -struct TOSALegalizationPipelineOptions - : public PassPipelineOptions {}; - std::unique_ptr> createLegalizeTFPass(); std::unique_ptr> createFuseBiasTFPass(); std::unique_ptr> createLegalizeTFLPass(); @@ -36,7 +32,6 @@ std::unique_ptr> createConvertTFLUint8Pass(); #include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" } // namespace tosa - } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H diff --git a/tensorflow/compiler/mlir/tosa/transforms/register_passes.h b/tensorflow/compiler/mlir/tosa/transforms/register_passes.h deleted file mode 100644 index 7d13205a42f..00000000000 --- a/tensorflow/compiler/mlir/tosa/transforms/register_passes.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H -#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H - -#include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Pass/Pass.h" -#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" - -namespace mlir { -namespace tosa { - -inline void registerAllTosaPasses() { - registerLegalizeTosaPasses(); - registerTosaOptPasses(); -} - -} // namespace tosa -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H