From 771ba8772a7308a5c39e9cc6852ab9b2478398f2 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 4 Feb 2020 15:42:23 -0800 Subject: [PATCH] [TF:MLIR:CPU] Add mark function visibility pass to tfcompile. Enable tffunction test for tfcompile-mlir. PiperOrigin-RevId: 293245780 Change-Id: Id28d8dfebc78700a28dbe01be7be2d9f77e01c1f --- tensorflow/compiler/aot/tests/BUILD | 13 +++++++++++++ tensorflow/compiler/aot/tests/tfcompile_test.cc | 4 +--- tensorflow/compiler/tf2xla/mlir_tf2xla.cc | 7 ++++++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 2f1e69d9ff1..a59176a8ece 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -349,6 +349,18 @@ tf_library( ], ) +tf_library( + name = "test_graph_tffunction_mlir_bridge", + testonly = 1, + config = "test_graph_tffunction.config.pbtxt", + cpp_class = "FunctionComp", + graph = "test_graph_tffunction.pb", + mlir_components = "Bridge", + tags = [ + "manual", + ], +) + tf_library( name = "test_graph_tfassert_eq_mlir_bridge", testonly = 1, @@ -484,6 +496,7 @@ tf_cc_test( ":test_graph_tfadd_with_ckpt_saver_mlir_bridge", ":test_graph_tfassert_eq_mlir_bridge", ":test_graph_tfcond_mlir_bridge", + ":test_graph_tffunction_mlir_bridge", ":test_graph_tfgather_mlir_bridge", ":test_graph_tfmatmul_mlir_bridge", ":test_graph_tfmatmulandadd_mlir_bridge", diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index b376f107c97..870fdc30053 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h" +#include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h" @@ -429,8 +430,6 @@ TEST(TFCompileTest, MatMulAndAdd1) { } } -// TODO(bixia): the following tests failed with MLIR bridge. -#if !defined(ENABLE_MLIR_BRIDGE_TEST) TEST(TFCompileTest, Function) { // The function is equivalent to an addition FunctionComp add_fn; @@ -445,7 +444,6 @@ TEST(TFCompileTest, Function) { EXPECT_EQ(add_fn.result0_data()[0], 3); EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]); } -#endif TEST(TFCompileTest, Splits) { Eigen::ThreadPool tp(1); diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index c2005304d65..6443e6cb8af 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -108,6 +108,11 @@ Status ConvertGraphDefToXlaViaMlir(const GraphDef& graph_def, device_set.AddDevice(&device); AddDevicesToOp(*module, &device_set); + if (failed(mlir::TF::MarkFunctionVisibilityUsingEntryFunctionSpecification( + *module))) { + return errors::Internal("Problem with mark function visibility"); + } + TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline( *module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));