[TF:MLIR:CPU] Add mark function visibility pass to tfcompile.
Enable tffunction test for tfcompile-mlir. PiperOrigin-RevId: 293245780 Change-Id: Id28d8dfebc78700a28dbe01be7be2d9f77e01c1f
This commit is contained in:
parent
074b026b77
commit
771ba8772a
@ -349,6 +349,18 @@ tf_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tffunction_mlir_bridge",
|
||||
testonly = 1,
|
||||
config = "test_graph_tffunction.config.pbtxt",
|
||||
cpp_class = "FunctionComp",
|
||||
graph = "test_graph_tffunction.pb",
|
||||
mlir_components = "Bridge",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfassert_eq_mlir_bridge",
|
||||
testonly = 1,
|
||||
@ -484,6 +496,7 @@ tf_cc_test(
|
||||
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
|
||||
":test_graph_tfassert_eq_mlir_bridge",
|
||||
":test_graph_tfcond_mlir_bridge",
|
||||
":test_graph_tffunction_mlir_bridge",
|
||||
":test_graph_tfgather_mlir_bridge",
|
||||
":test_graph_tfmatmul_mlir_bridge",
|
||||
":test_graph_tfmatmulandadd_mlir_bridge",
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
|
||||
@ -429,8 +430,6 @@ TEST(TFCompileTest, MatMulAndAdd1) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(bixia): the following tests failed with MLIR bridge.
|
||||
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
|
||||
TEST(TFCompileTest, Function) {
|
||||
// The function is equivalent to an addition
|
||||
FunctionComp add_fn;
|
||||
@ -445,7 +444,6 @@ TEST(TFCompileTest, Function) {
|
||||
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
||||
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(TFCompileTest, Splits) {
|
||||
Eigen::ThreadPool tp(1);
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||
@ -28,7 +29,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
|
||||
@ -108,6 +108,11 @@ Status ConvertGraphDefToXlaViaMlir(const GraphDef& graph_def,
|
||||
device_set.AddDevice(&device);
|
||||
AddDevicesToOp(*module, &device_set);
|
||||
|
||||
if (failed(mlir::TF::MarkFunctionVisibilityUsingEntryFunctionSpecification(
|
||||
*module))) {
|
||||
return errors::Internal("Problem with mark function visibility");
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline(
|
||||
*module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user