diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index dc327a2d519..2f1e69d9ff1 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -447,6 +447,30 @@ tf_library( ], ) +tf_library( + name = "test_graph_tfvariable_mlir_bridge", + testonly = 1, + config = "test_graph_tfvariable.config.pbtxt", + cpp_class = "VariableComp", + graph = "test_graph_tfvariable.pb", + mlir_components = "Bridge", + tags = [ + "manual", + ], +) + +tf_library( + name = "test_graph_tfvariable_sequential_updates_mlir_bridge", + testonly = 1, + config = "test_graph_tfvariable_sequential_updates.config.pbtxt", + cpp_class = "VariableSequentialUpdatesComp", + graph = "test_graph_tfvariable_sequential_updates.pb", + mlir_components = "Bridge", + tags = [ + "manual", + ], +) + tf_cc_test( name = "tfcompile_test_mlir_bridge", srcs = ["tfcompile_test.cc"], @@ -466,7 +490,9 @@ tf_cc_test( ":test_graph_tfmatmulandadd_with_profiling_mlir_bridge", ":test_graph_tfsplits_mlir_bridge", ":test_graph_tftop_k_mlir_bridge", + ":test_graph_tfvariable_mlir_bridge", ":test_graph_tfvariable_readonly_mlir_bridge", + ":test_graph_tfvariable_sequential_updates_mlir_bridge", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto_cc", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index ae8c40c426c..a96ba0e6919 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -162,11 +162,13 @@ def tfvariable_readonly(_): array_ops.identity(new_value, name='result') +# TODO(b/147908587): Change x and the two constants back to have a scalar shape +# when the bug is fixed. def tfvariable(_): - x = variables.Variable(1000.0, name='x') + x = variables.Variable([1000.0], name='x', shape=[1]) old_x = x.value() with ops.control_dependencies([old_x]): - new_x = x.assign_add(42.0) + new_x = x.assign_add([42.0]) array_ops.stack([old_x, new_x], name='result') diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index e4de9ea0f8b..b376f107c97 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -38,7 +38,9 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h" #else #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" @@ -511,8 +513,6 @@ TEST(TFCompileTest, VariableReadonly) { EXPECT_EQ(fn.var_x(), 23); } -// TODO(bixia): the following tests failed with MLIR bridge. -#if !defined(ENABLE_MLIR_BRIDGE_TEST) TEST(TFCompileTest, Variable) { Eigen::ThreadPool tp(1); Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); @@ -585,7 +585,6 @@ TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) { fn.Run(); EXPECT_NEAR(x, 0.594322f, 1e-6); } -#endif TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the