[TF:MLIR:CPU] Enable read/write resource variable test for tfcompile-mlir.
Modify the python program that generates a graph for tfvariable test to use the same shape that the test uses to workaround a bug in the GraphDef to MLIR importer. PiperOrigin-RevId: 291217424 Change-Id: Id372a6846e930dc77f5042eeadc71dc6a0cb8dbb
This commit is contained in:
parent
816bd4ba34
commit
abf7616950
@ -447,6 +447,30 @@ tf_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfvariable_mlir_bridge",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfvariable.config.pbtxt",
|
||||
cpp_class = "VariableComp",
|
||||
graph = "test_graph_tfvariable.pb",
|
||||
mlir_components = "Bridge",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfvariable_sequential_updates_mlir_bridge",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
|
||||
cpp_class = "VariableSequentialUpdatesComp",
|
||||
graph = "test_graph_tfvariable_sequential_updates.pb",
|
||||
mlir_components = "Bridge",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tfcompile_test_mlir_bridge",
|
||||
srcs = ["tfcompile_test.cc"],
|
||||
@ -466,7 +490,9 @@ tf_cc_test(
|
||||
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
|
||||
":test_graph_tfsplits_mlir_bridge",
|
||||
":test_graph_tftop_k_mlir_bridge",
|
||||
":test_graph_tfvariable_mlir_bridge",
|
||||
":test_graph_tfvariable_readonly_mlir_bridge",
|
||||
":test_graph_tfvariable_sequential_updates_mlir_bridge",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
|
@ -162,11 +162,13 @@ def tfvariable_readonly(_):
|
||||
array_ops.identity(new_value, name='result')
|
||||
|
||||
|
||||
# TODO(b/147908587): Change x and the two constants back to have a scalar shape
|
||||
# when the bug is fixed.
|
||||
def tfvariable(_):
|
||||
x = variables.Variable(1000.0, name='x')
|
||||
x = variables.Variable([1000.0], name='x', shape=[1])
|
||||
old_x = x.value()
|
||||
with ops.control_dependencies([old_x]):
|
||||
new_x = x.assign_add(42.0)
|
||||
new_x = x.assign_add([42.0])
|
||||
array_ops.stack([old_x, new_x], name='result')
|
||||
|
||||
|
||||
|
@ -38,7 +38,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h"
|
||||
#else
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
||||
@ -511,8 +513,6 @@ TEST(TFCompileTest, VariableReadonly) {
|
||||
EXPECT_EQ(fn.var_x(), 23);
|
||||
}
|
||||
|
||||
// TODO(bixia): the following tests failed with MLIR bridge.
|
||||
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
|
||||
TEST(TFCompileTest, Variable) {
|
||||
Eigen::ThreadPool tp(1);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
@ -585,7 +585,6 @@ TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
|
||||
fn.Run();
|
||||
EXPECT_NEAR(x, 0.594322f, 1e-6);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(TFCompileTest, AssertEqAndReturnDiff) {
|
||||
// Assert is converted into a no-op in XLA, so there is no failure even if the
|
||||
|
Loading…
Reference in New Issue
Block a user