[TF:MLIR:CPU] Add mark function visibility pass to tfcompile.

Enable tffunction test for tfcompile-mlir.

PiperOrigin-RevId: 293245780
Change-Id: Id28d8dfebc78700a28dbe01be7be2d9f77e01c1f
This commit is contained in:
Bixia Zheng 2020-02-04 15:42:23 -08:00 committed by TensorFlower Gardener
parent 074b026b77
commit 771ba8772a
3 changed files with 20 additions and 4 deletions

View File

@ -349,6 +349,18 @@ tf_library(
],
)
tf_library(
name = "test_graph_tffunction_mlir_bridge",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfassert_eq_mlir_bridge",
testonly = 1,
@ -484,6 +496,7 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
":test_graph_tfassert_eq_mlir_bridge",
":test_graph_tfcond_mlir_bridge",
":test_graph_tffunction_mlir_bridge",
":test_graph_tfgather_mlir_bridge",
":test_graph_tfmatmul_mlir_bridge",
":test_graph_tfmatmulandadd_mlir_bridge",

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
@ -429,8 +430,6 @@ TEST(TFCompileTest, MatMulAndAdd1) {
}
}
// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
@ -445,7 +444,6 @@ TEST(TFCompileTest, Function) {
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
#endif
TEST(TFCompileTest, Splits) {
Eigen::ThreadPool tp(1);

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@ -108,6 +108,11 @@ Status ConvertGraphDefToXlaViaMlir(const GraphDef& graph_def,
device_set.AddDevice(&device);
AddDevicesToOp(*module, &device_set);
if (failed(mlir::TF::MarkFunctionVisibilityUsingEntryFunctionSpecification(
*module))) {
return errors::Internal("Problem with mark function visibility");
}
TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline(
*module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));