Split out dialect hooks into separate targets

This allows not linking in the hooks too if, for example, one wants the TF dialect but not constant folding via fallback hook.

PiperOrigin-RevId: 347295194
Change-Id: Iaf5af9c4c0c0ed00e5cc91ecc39cc4043c5ca0b6
This commit is contained in:
Jacques Pienaar 2020-12-13 17:25:16 -08:00 committed by TensorFlower Gardener
parent 3727302f03
commit edc060801f
7 changed files with 36 additions and 22 deletions
tensorflow/compiler/mlir

View File

@ -109,7 +109,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_pass",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
"//tensorflow/compiler/mlir/tosa:tf_passes",

View File

@ -937,8 +937,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_hooks",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",

View File

@ -1106,7 +1106,6 @@ cc_library(
":mlir_roundtrip_flags",
":tensorflow",
":tensorflow_attributes",
":tensorflow_passes",
":tensorflow_types",
":tf_saved_model_passes",
":translate_utils",
@ -1450,27 +1449,21 @@ cc_library(
)
cc_library(
name = "tf_dialect_passes",
name = "tf_constant_fallback_hook",
srcs = [
"transforms/constant_fold.cc",
"transforms/decode_attributes_hook.cc",
],
hdrs = [
"transforms/constant_fold.h",
],
deps = [
":convert_tensor",
":decode_constant_pass",
":eval_util",
":tensorflow",
":tensorflow_traits",
":tensorflow_types",
"//tensorflow/c:tf_status",
"//tensorflow/c/eager:c_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects",
@ -1480,13 +1473,39 @@ cc_library(
)
cc_library(
name = "tf_dialect_lib",
name = "tf_decode_attributes_hook",
srcs = [
"transforms/decode_attributes_hook.cc",
],
deps = [
":tf_dialect_passes",
":convert_tensor",
":decode_constant_pass",
":tensorflow",
"//tensorflow/core:framework",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "tf_dialect_hooks",
deps = [
":tf_constant_fallback_hook",
":tf_decode_attributes_hook",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
],
)
# TODO(jpienaar): Remove post updating all.
alias(
name = "tf_dialect_lib",
actual = ":tf_dialect_hooks",
)
cc_library(
name = "tf_graph_optimization_pass",
srcs = ["transforms/tf_graph_optimization_pass.cc"],
@ -1702,8 +1721,8 @@ cc_library(
name = "compile_mlir_util",
hdrs = ["utils/compile_mlir_util.h"],
deps = COMPILE_MLIR_UTIL_DEPS + [
"compile_mlir_util_no_tf_dialect_passes",
":tf_dialect_passes",
":compile_mlir_util_no_tf_dialect_passes",
":tf_dialect_hooks",
],
)

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"

View File

@ -25,9 +25,9 @@ limitations under the License.
namespace mlir {
namespace TF {
LogicalResult ConstantFoldFallbackHook(
Operation *inst, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results); // NOLINT
LogicalResult ConstantFoldFallbackHook(Operation *inst,
ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results);
} // namespace TF
} // namespace mlir

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/stream_executor/lib/statusor.h"

View File

@ -175,8 +175,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_hooks",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:framework",