Split out dialect hooks into separate targets

This allows not linking in the hooks too if, for example, one wants the TF dialect but not constant folding via fallback hook.

PiperOrigin-RevId: 347295194
Change-Id: Iaf5af9c4c0c0ed00e5cc91ecc39cc4043c5ca0b6
This commit is contained in:
Jacques Pienaar 2020-12-13 17:25:16 -08:00 committed by TensorFlower Gardener
parent 3727302f03
commit edc060801f
7 changed files with 36 additions and 22 deletions

View File

@ -109,7 +109,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_pass", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_pass",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes", "//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
"//tensorflow/compiler/mlir/tosa:tf_passes", "//tensorflow/compiler/mlir/tosa:tf_passes",

View File

@ -937,8 +937,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_hooks",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",

View File

@ -1106,7 +1106,6 @@ cc_library(
":mlir_roundtrip_flags", ":mlir_roundtrip_flags",
":tensorflow", ":tensorflow",
":tensorflow_attributes", ":tensorflow_attributes",
":tensorflow_passes",
":tensorflow_types", ":tensorflow_types",
":tf_saved_model_passes", ":tf_saved_model_passes",
":translate_utils", ":translate_utils",
@ -1450,27 +1449,21 @@ cc_library(
) )
cc_library( cc_library(
name = "tf_dialect_passes", name = "tf_constant_fallback_hook",
srcs = [ srcs = [
"transforms/constant_fold.cc", "transforms/constant_fold.cc",
"transforms/decode_attributes_hook.cc",
], ],
hdrs = [ hdrs = [
"transforms/constant_fold.h", "transforms/constant_fold.h",
], ],
deps = [ deps = [
":convert_tensor",
":decode_constant_pass", ":decode_constant_pass",
":eval_util", ":eval_util",
":tensorflow", ":tensorflow",
":tensorflow_traits", ":tensorflow_traits",
":tensorflow_types", ":tensorflow_types",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/c/eager:c_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects", "@llvm-project//mlir:SideEffects",
@ -1480,13 +1473,39 @@ cc_library(
) )
cc_library( cc_library(
name = "tf_dialect_lib", name = "tf_decode_attributes_hook",
srcs = [
"transforms/decode_attributes_hook.cc",
],
deps = [ deps = [
":tf_dialect_passes", ":convert_tensor",
":decode_constant_pass",
":tensorflow",
"//tensorflow/core:framework",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "tf_dialect_hooks",
deps = [
":tf_constant_fallback_hook",
":tf_decode_attributes_hook",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
], ],
) )
# TODO(jpienaar): Remove post updating all.
alias(
name = "tf_dialect_lib",
actual = ":tf_dialect_hooks",
)
cc_library( cc_library(
name = "tf_graph_optimization_pass", name = "tf_graph_optimization_pass",
srcs = ["transforms/tf_graph_optimization_pass.cc"], srcs = ["transforms/tf_graph_optimization_pass.cc"],
@ -1702,8 +1721,8 @@ cc_library(
name = "compile_mlir_util", name = "compile_mlir_util",
hdrs = ["utils/compile_mlir_util.h"], hdrs = ["utils/compile_mlir_util.h"],
deps = COMPILE_MLIR_UTIL_DEPS + [ deps = COMPILE_MLIR_UTIL_DEPS + [
"compile_mlir_util_no_tf_dialect_passes", ":compile_mlir_util_no_tf_dialect_passes",
":tf_dialect_passes", ":tf_dialect_hooks",
], ],
) )

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"

View File

@ -25,9 +25,9 @@ limitations under the License.
namespace mlir { namespace mlir {
namespace TF { namespace TF {
LogicalResult ConstantFoldFallbackHook( LogicalResult ConstantFoldFallbackHook(Operation *inst,
Operation *inst, ArrayRef<Attribute> operands, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results); // NOLINT SmallVectorImpl<OpFoldResult> &results);
} // namespace TF } // namespace TF
} // namespace mlir } // namespace mlir

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/logging.h" #include "tensorflow/core/framework/logging.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/statusor.h"

View File

@ -175,8 +175,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_hooks",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:framework", "//tensorflow/core:framework",