diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index e0074545d33..77db4eb43be 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -110,6 +110,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_pass", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes", "//tensorflow/compiler/mlir/tosa:tf_passes", diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 72e5799c5c9..664dfe0e3ba 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -937,7 +937,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow:tf_dialect_hooks", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 301e3ba9151..c74d47404e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1107,6 +1107,7 @@ cc_library( ":mlir_roundtrip_flags", ":tensorflow", ":tensorflow_attributes", + ":tensorflow_passes", ":tensorflow_types", ":tf_saved_model_passes", ":translate_utils", @@ -1450,21 +1451,27 @@ cc_library( ) cc_library( - name = "tf_constant_fallback_hook", + name = "tf_dialect_passes", srcs = [ "transforms/constant_fold.cc", + "transforms/decode_attributes_hook.cc", ], hdrs = [ "transforms/constant_fold.h", ], deps = [ + ":convert_tensor", ":decode_constant_pass", ":eval_util", ":tensorflow", ":tensorflow_traits", ":tensorflow_types", "//tensorflow/c:tf_status", + "//tensorflow/c/eager:c_api", + "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/stream_executor", + "//tensorflow/stream_executor/lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SideEffects", @@ -1474,39 +1481,13 @@ cc_library( ) cc_library( - name = "tf_decode_attributes_hook", - srcs = [ - "transforms/decode_attributes_hook.cc", - ], + name = "tf_dialect_lib", deps = [ - ":convert_tensor", - ":decode_constant_pass", - ":tensorflow", - "//tensorflow/core:framework", - "//tensorflow/stream_executor", - "//tensorflow/stream_executor/lib", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], - alwayslink = 1, -) - -cc_library( - name = "tf_dialect_hooks", - deps = [ - ":tf_constant_fallback_hook", - ":tf_decode_attributes_hook", + ":tf_dialect_passes", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", ], ) -# TODO(jpienaar): Remove post updating all. -alias( - name = "tf_dialect_lib", - actual = ":tf_dialect_hooks", -) - cc_library( name = "tf_graph_optimization_pass", srcs = ["transforms/tf_graph_optimization_pass.cc"], @@ -1722,8 +1703,8 @@ cc_library( name = "compile_mlir_util", hdrs = ["utils/compile_mlir_util.h"], deps = COMPILE_MLIR_UTIL_DEPS + [ - ":compile_mlir_util_no_tf_dialect_passes", - ":tf_dialect_hooks", + "compile_mlir_util_no_tf_dialect_passes", + ":tf_dialect_passes", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index a3c487f6378..31cfc5ebf9c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h index 54f296dcb2f..887eea745e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h @@ -25,9 +25,9 @@ limitations under the License. namespace mlir { namespace TF { -LogicalResult ConstantFoldFallbackHook(Operation *inst, - ArrayRef operands, - SmallVectorImpl &results); +LogicalResult ConstantFoldFallbackHook( + Operation *inst, ArrayRef operands, + SmallVectorImpl &results); // NOLINT } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc index 9dbf332fc67..09fac6e0706 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/core/framework/logging.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD index 66b9a5493ce..a337dc02a9e 100644 --- a/tensorflow/compiler/mlir/tfjs/BUILD +++ b/tensorflow/compiler/mlir/tfjs/BUILD @@ -175,7 +175,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:tf_dialect_hooks", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:framework",