[TF:MLIR:CPU] Add mark function visibility pass to tfcompile.
Enable tffunction test for tfcompile-mlir. PiperOrigin-RevId: 293245780 Change-Id: Id28d8dfebc78700a28dbe01be7be2d9f77e01c1f
This commit is contained in:
parent
074b026b77
commit
771ba8772a
@ -349,6 +349,18 @@ tf_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tffunction_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tffunction.config.pbtxt",
|
||||||
|
cpp_class = "FunctionComp",
|
||||||
|
graph = "test_graph_tffunction.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_library(
|
tf_library(
|
||||||
name = "test_graph_tfassert_eq_mlir_bridge",
|
name = "test_graph_tfassert_eq_mlir_bridge",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
@ -484,6 +496,7 @@ tf_cc_test(
|
|||||||
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
|
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
|
||||||
":test_graph_tfassert_eq_mlir_bridge",
|
":test_graph_tfassert_eq_mlir_bridge",
|
||||||
":test_graph_tfcond_mlir_bridge",
|
":test_graph_tfcond_mlir_bridge",
|
||||||
|
":test_graph_tffunction_mlir_bridge",
|
||||||
":test_graph_tfgather_mlir_bridge",
|
":test_graph_tfgather_mlir_bridge",
|
||||||
":test_graph_tfmatmul_mlir_bridge",
|
":test_graph_tfmatmul_mlir_bridge",
|
||||||
":test_graph_tfmatmulandadd_mlir_bridge",
|
":test_graph_tfmatmulandadd_mlir_bridge",
|
||||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
|
||||||
@ -429,8 +430,6 @@ TEST(TFCompileTest, MatMulAndAdd1) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(bixia): the following tests failed with MLIR bridge.
|
|
||||||
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
|
|
||||||
TEST(TFCompileTest, Function) {
|
TEST(TFCompileTest, Function) {
|
||||||
// The function is equivalent to an addition
|
// The function is equivalent to an addition
|
||||||
FunctionComp add_fn;
|
FunctionComp add_fn;
|
||||||
@ -445,7 +444,6 @@ TEST(TFCompileTest, Function) {
|
|||||||
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
||||||
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
TEST(TFCompileTest, Splits) {
|
TEST(TFCompileTest, Splits) {
|
||||||
Eigen::ThreadPool tp(1);
|
Eigen::ThreadPool tp(1);
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||||
@ -28,7 +29,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
|
|
||||||
@ -108,6 +108,11 @@ Status ConvertGraphDefToXlaViaMlir(const GraphDef& graph_def,
|
|||||||
device_set.AddDevice(&device);
|
device_set.AddDevice(&device);
|
||||||
AddDevicesToOp(*module, &device_set);
|
AddDevicesToOp(*module, &device_set);
|
||||||
|
|
||||||
|
if (failed(mlir::TF::MarkFunctionVisibilityUsingEntryFunctionSpecification(
|
||||||
|
*module))) {
|
||||||
|
return errors::Internal("Problem with mark function visibility");
|
||||||
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline(
|
TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline(
|
||||||
*module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));
|
*module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user