NFC: Simplify Tosa build and layering.

* Eliminate some intermediate source files, collapsing into two user facing libraries.
* Fix layering of dependencies so user libraries are self contained.
* Take a pass through and eliminate redundant includes.
* Separate TF and TFL passes as the latter can be built with much fewer dependencies (and they will diverge with respect to some of the things that were common).

PiperOrigin-RevId: 347113596
Change-Id: Ie5d0016750d7be49cd52ee06d349d798b177c217
This commit is contained in:
Stella Laurenzo 2020-12-11 18:29:41 -08:00 committed by TensorFlower Gardener
parent 4b4893e25b
commit 14a0a90f9b
17 changed files with 222 additions and 476 deletions

View File

@ -112,8 +112,8 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
"//tensorflow/compiler/mlir/tosa:tf_tosa_passes",
"//tensorflow/compiler/mlir/tosa:tfl_tosa_passes",
"//tensorflow/compiler/mlir/tosa:tf_passes",
"//tensorflow/compiler/mlir/tosa:tfl_passes",
],
)

View File

@ -8,7 +8,7 @@ load("//third_party/mlir:tblgen.bzl", "gentbl")
# TODO: Tighten visibility once targets are at the right granularity.
package(
default_visibility = [":friends"],
default_visibility = [":internal"],
licenses = ["notice"], # Apache 2.0
)
@ -33,17 +33,13 @@ package_group(
filegroup(
name = "tosa_ops_td_files",
srcs = [
"@llvm-project//mlir:TdFiles",
"@llvm-project//mlir:TosaDialectTdFiles",
],
# TODO: Switch to pruned list of TD files once build file changes land.
# srcs = [
# "@llvm-project//mlir:TosaDialectTdFiles",
# ],
compatible_with = get_compatible_with_cloud(),
)
gentbl(
name = "tosa_pass_inc_gen",
name = "tosa_passes_inc_gen",
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
(
@ -58,6 +54,40 @@ gentbl(
],
)
cc_library(
name = "passes_header",
hdrs = [
"transforms/passes.h",
"transforms/passes.h.inc",
],
compatible_with = get_compatible_with_cloud(),
deps = ["@llvm-project//mlir:Pass"],
)
cc_library(
name = "legalize_common",
srcs = [
"transforms/legalize_common.cc",
"transforms/legalize_utils.cc",
],
hdrs = [
"transforms/legalize_common.h",
"transforms/legalize_utils.h",
],
compatible_with = get_compatible_with_cloud(),
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TosaDialect",
],
alwayslink = 1,
)
gentbl(
name = "tosa_legalize_tf_inc_gen",
compatible_with = get_compatible_with_cloud(),
@ -76,6 +106,36 @@ gentbl(
],
)
cc_library(
name = "tf_passes",
srcs = [
"tf_passes.cc",
"transforms/fuse_bias_tf.cc",
"transforms/legalize_tf.cc",
"transforms/tf_legalize_patterns.inc",
],
hdrs = [
"tf_passes.h",
"transforms/passes.h",
],
compatible_with = get_compatible_with_cloud(),
visibility = [":friends"],
deps = [
":legalize_common",
":passes_header",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
gentbl(
name = "tosa_legalize_tfl_inc_gen",
compatible_with = get_compatible_with_cloud(),
@ -95,233 +155,31 @@ gentbl(
)
cc_library(
name = "tosa_legalize_tf",
srcs = [
"transforms/legalize_tf.cc",
"transforms/tf_legalize_patterns.inc",
],
hdrs = [
"transforms/legalize_common.h",
"transforms/legalize_utils.h",
"transforms/passes.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
compatible_with = get_compatible_with_cloud(),
deps = [
":tosa_legalize_tf_inc_gen",
":tosa_pass_inc_gen",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_all_ops_inc_gen",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@flatbuffers",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "tosa_legalize_tfl",
name = "tfl_passes",
srcs = [
"tfl_passes.cc",
"transforms/convert_tfl_uint8.cc",
"transforms/legalize_tfl.cc",
"transforms/tfl_legalize_patterns.inc",
],
hdrs = [
"transforms/legalize_common.h",
"transforms/legalize_utils.h",
"tfl_passes.h",
"transforms/passes.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
compatible_with = get_compatible_with_cloud(),
visibility = [":friends"],
deps = [
":tosa_legalize_tfl_inc_gen",
":tosa_pass_inc_gen",
":legalize_common",
":passes_header",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_inc_gen",
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@flatbuffers",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "tosa_legalize_common",
srcs = [
"transforms/legalize_common.cc",
"transforms/legalize_utils.cc",
"transforms/tf_legalize_patterns.inc",
],
hdrs = [
"transforms/legalize_common.h",
"transforms/legalize_utils.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
compatible_with = get_compatible_with_cloud(),
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_inc_gen",
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_all_ops_inc_gen",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@flatbuffers",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "tosa_fuse_bias_tf",
srcs = [
"transforms/fuse_bias_tf.cc",
],
hdrs = [
"transforms/passes.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
compatible_with = get_compatible_with_cloud(),
deps = [
":tosa_legalize_common",
":tosa_pass_inc_gen",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "tosa_convert_tfl_uint8",
srcs = [
"transforms/convert_tfl_uint8.cc",
],
hdrs = [
"transforms/passes.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
compatible_with = get_compatible_with_cloud(),
deps = [
":tosa_legalize_common",
":tosa_pass_inc_gen",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "tosa_pipelines",
srcs = [
"tosa_passpipes.cc",
],
hdrs = [
"tosa_passpipes.h",
"transforms/passes.h",
"transforms/register_passes.h",
],
compatible_with = get_compatible_with_cloud(),
deps = [
":tosa_pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "tf_tosa_passes",
srcs = [
"tf_tosa_pipeline.cc",
],
hdrs = [
],
compatible_with = get_compatible_with_cloud(),
deps = [
":tosa_fuse_bias_tf",
":tosa_legalize_common",
":tosa_legalize_tf",
":tosa_pipelines",
],
alwayslink = 1,
)
cc_library(
name = "tfl_tosa_passes",
srcs = [
"tfl_tosa_pipeline.cc",
],
hdrs = [
],
compatible_with = get_compatible_with_cloud(),
deps = [
":tosa_convert_tfl_uint8",
":tosa_legalize_common",
":tosa_legalize_tfl",
":tosa_pipelines",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tosa/tf_passes.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
namespace mlir {
namespace tosa {
void createTFtoTOSALegalizationPipeline(
OpPassManager& pm, const TOSATFLegalizationPipelineOptions& opts) {
//----------------------------------------------------------------------------
// Prepare TFL module for conversion
//----------------------------------------------------------------------------
// Inline all functions into main and then delete the functions themselves.
pm.addPass(mlir::createInlinerPass());
// Now that there is only one function, run some MLIR passes on it.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createLoopFusionPass());
pm.addPass(mlir::createMemRefDataFlowOptPass());
//----------------------------------------------------------------------------
// Perform main conversion.
// Now that there is only one function, run some MLIR passes on it.
//----------------------------------------------------------------------------
pm.addPass(mlir::tosa::createFuseBiasTFPass());
pm.addPass(mlir::tosa::createLegalizeTFPass());
//----------------------------------------------------------------------------
// Post conversion cleanup.
//----------------------------------------------------------------------------
pm.addPass(mlir::tosa::createTosaMakeBroadcastablePass());
// Inline the call/return basic blocks within TOSA control flow ops.
pm.addPass(mlir::createInlinerPass());
// Clean up with DCE.
pm.addPass(mlir::createSymbolDCEPass());
}
static mlir::PassPipelineRegistration<TOSATFLegalizationPipelineOptions>
tf_tosa_pipeline("tf-to-tosa-pipeline",
"TensorFlow to TOSA legalization pipeline",
createTFtoTOSALegalizationPipeline);
} // namespace tosa
} // namespace mlir

View File

@ -16,28 +16,20 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H
#define TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/Optional.h"
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Pass/PassOptions.h" // from @llvm-project
namespace mlir {
namespace tosa {
void addPreOptMlirPasses(mlir::OpPassManager& pm);
void addPostOptMlirPasses(mlir::OpPassManager& pm);
struct TOSATFLegalizationPipelineOptions
: public PassPipelineOptions<TOSATFLegalizationPipelineOptions> {};
// Legalizes TF dialect(s) to Tosa.
void createTFtoTOSALegalizationPipeline(
OpPassManager& pm, const TOSALegalizationPipelineOptions& opts);
void createTFLtoTOSALegalizationPipeline(
OpPassManager& pm, const TOSALegalizationPipelineOptions& opts);
OpPassManager& pm, const TOSATFLegalizationPipelineOptions& opts);
} // namespace tosa
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TOSA_PASSES_H

View File

@ -13,23 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
namespace mlir {
namespace tosa {
void addPreOptMlirPasses(mlir::OpPassManager& pm) {
void createTFLtoTOSALegalizationPipeline(
OpPassManager& pm, const TOSATFLLegalizationPipelineOptions& opts) {
//----------------------------------------------------------------------------
// Prepare TFL module for conversion
//----------------------------------------------------------------------------
// Inline all functions into main and then delete the functions themselves.
pm.addPass(mlir::createInlinerPass());
@ -39,9 +36,16 @@ void addPreOptMlirPasses(mlir::OpPassManager& pm) {
pm.addPass(mlir::createLoopFusionPass());
pm.addPass(mlir::createMemRefDataFlowOptPass());
}
void addPostOptMlirPasses(mlir::OpPassManager& pm) {
//----------------------------------------------------------------------------
// Perform main conversion.
//----------------------------------------------------------------------------
pm.addPass(mlir::tosa::createConvertTFLUint8Pass());
pm.addPass(mlir::tosa::createLegalizeTFLPass());
//----------------------------------------------------------------------------
// Post conversion cleanup.
//----------------------------------------------------------------------------
pm.addPass(mlir::tosa::createTosaMakeBroadcastablePass());
// Inline the call/return basic blocks within TOSA control flow ops.
pm.addPass(mlir::createInlinerPass());
@ -49,26 +53,10 @@ void addPostOptMlirPasses(mlir::OpPassManager& pm) {
pm.addPass(mlir::createSymbolDCEPass());
}
void createTFtoTOSALegalizationPipeline(
OpPassManager& pm, const TOSALegalizationPipelineOptions& opts) {
addPreOptMlirPasses(pm);
pm.addPass(mlir::tosa::createFuseBiasTFPass());
pm.addPass(mlir::tosa::createLegalizeTFPass());
addPostOptMlirPasses(pm);
}
void createTFLtoTOSALegalizationPipeline(
OpPassManager& pm, const TOSALegalizationPipelineOptions& opts) {
addPreOptMlirPasses(pm);
pm.addPass(mlir::tosa::createConvertTFLUint8Pass());
pm.addPass(mlir::tosa::createLegalizeTFLPass());
addPostOptMlirPasses(pm);
}
static mlir::PassPipelineRegistration<TOSATFLLegalizationPipelineOptions>
tfl_tosa_pipeline("tfl-to-tosa-pipeline",
"TensorFlow Lite to TOSA legalization pipeline",
createTFLtoTOSALegalizationPipeline);
} // namespace tosa
} // namespace mlir

View File

@ -13,17 +13,23 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Pass/PassOptions.h" // from @llvm-project
namespace mlir {
namespace tosa {
static mlir::PassPipelineRegistration<TOSALegalizationPipelineOptions>
tf_tosa_pipeline("tf-to-tosa-pipeline",
"TensorFlow to TOSA legalization pipeline",
createTFtoTOSALegalizationPipeline);
struct TOSATFLLegalizationPipelineOptions
: public PassPipelineOptions<TOSATFLLegalizationPipelineOptions> {};
// Legalizes TFL (TensorFlow lite) dialect(s) to Tosa.
void createTFLtoTOSALegalizationPipeline(
OpPassManager& pm, const TOSATFLLegalizationPipelineOptions& opts);
} // namespace tosa
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_

View File

@ -1,29 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
namespace mlir {
namespace tosa {
static mlir::PassPipelineRegistration<TOSALegalizationPipelineOptions>
tfl_tosa_pipeline("tfl-to-tosa-pipeline",
"TensorFlow Lite to TOSA legalization pipeline",
createTFLtoTOSALegalizationPipeline);
} // namespace tosa
} // namespace mlir

View File

@ -29,25 +29,14 @@ limitations under the License.
#include <iterator>
#include <numeric>
#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"

View File

@ -21,23 +21,10 @@ limitations under the License.
#include <iterator>
#include <numeric>
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"

View File

@ -14,18 +14,26 @@ limitations under the License.
==============================================================================*/
// This file contains legalizations common to mapping both TensorFlow and
// TensorFlow Lite to TOSA.
// TensorFlow Lite to TOSA. It operates generically on ops and does not have
// a hard reference on either dialect.
//
// Conversion functions return llvm::None on a legalization failure or a
// legalized value on success. Callers must check for presence of an
// llvm::Optional value after each call.
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
namespace mlir {

View File

@ -16,39 +16,17 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
// This file contains legalizations common to mapping both TensorFlow and
// TensorFlow Lite to TOSA.
//
// Conversion functions return nullptr on a lowerization failure or a lowered
// operator on success. Callers must check and return a LogicalResult failure
// on nullptr.
// Conversion functions return None on a failure or result value on success.
// Callers must check and return a LogicalResult failure on nullptr.
//
// For these functions, the framework-specific operands/attributes/defaults
// are already extracted and placed in a common form for lowering.
#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {

View File

@ -21,30 +21,9 @@ limitations under the License.
#include <iterator>
#include <numeric>
#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"

View File

@ -23,32 +23,9 @@ limitations under the License.
#include <numeric>
#include <unordered_set>
#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
@ -2996,5 +2973,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass() {
static PassRegistration<LegalizeTFL> pass(
PASS_NAME, "Legalize from TensorFlow Lite to TOSA dialect");
} // namespace tosa
} // namespace mlir

View File

@ -15,13 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
// Implements legalization and post-legalization optimization helper functions
namespace mlir {
namespace tosa {
// Create a TOSA rescale op from TFLite scaling, zero points and rounding mode

View File

@ -22,24 +22,10 @@ limitations under the License.
#include <iterator>
#include <numeric>
#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"

View File

@ -18,15 +18,11 @@ limitations under the License.
#include <memory>
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {
namespace tosa {
struct TOSALegalizationPipelineOptions
: public PassPipelineOptions<TOSALegalizationPipelineOptions> {};
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass();
std::unique_ptr<OperationPass<FuncOp>> createFuseBiasTFPass();
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass();
@ -36,7 +32,6 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertTFLUint8Pass();
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc"
} // namespace tosa
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H

View File

@ -1,34 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H
#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
namespace mlir {
namespace tosa {
inline void registerAllTosaPasses() {
registerLegalizeTosaPasses();
registerTosaOptPasses();
}
} // namespace tosa
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_REGISTER_PASSES_H