Add test with tf.cond.
PiperOrigin-RevId: 195745718
This commit is contained in:
parent
3964bdeef8
commit
db63348bf1
@ -15,6 +15,7 @@ test_suite(
|
||||
":test_graph_tfadd_with_ckpt_saver_test",
|
||||
":test_graph_tfadd_with_ckpt_test",
|
||||
":test_graph_tfassert_eq_test",
|
||||
":test_graph_tfcond_test",
|
||||
":test_graph_tffunction_test",
|
||||
":test_graph_tfgather_test",
|
||||
":test_graph_tfmatmul_test",
|
||||
@ -55,6 +56,7 @@ genrule(
|
||||
"test_graph_tfadd_with_ckpt_saver.pb",
|
||||
"test_graph_tfadd_with_ckpt_saver.saver",
|
||||
"test_graph_tfassert_eq.pb",
|
||||
"test_graph_tfcond.pb",
|
||||
"test_graph_tffunction.pb",
|
||||
"test_graph_tfgather.pb",
|
||||
"test_graph_tfmatmul.pb",
|
||||
@ -118,6 +120,17 @@ tf_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfcond",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfcond.config.pbtxt",
|
||||
cpp_class = "CondComp",
|
||||
graph = "test_graph_tfcond.pb",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tffunction",
|
||||
testonly = 1,
|
||||
@ -194,6 +207,7 @@ tf_cc_test(
|
||||
":test_graph_tfadd_with_ckpt",
|
||||
":test_graph_tfadd_with_ckpt_saver",
|
||||
":test_graph_tfassert_eq",
|
||||
":test_graph_tfcond",
|
||||
":test_graph_tffunction",
|
||||
":test_graph_tfgather",
|
||||
":test_graph_tfmatmul",
|
||||
|
@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir):
|
||||
f.write(saver.as_saver_def().SerializeToString())
|
||||
|
||||
|
||||
def tfassert_eq(_):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x_hold')
|
||||
y = array_ops.placeholder(dtypes.int32, name='y_hold')
|
||||
control_flow_ops.Assert(
|
||||
math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
|
||||
math_ops.add(x, math_ops.negative(y), name='x_y_diff')
|
||||
|
||||
|
||||
def tfcond(_):
|
||||
p = array_ops.placeholder(dtypes.bool, name='p_hold')
|
||||
x = array_ops.placeholder(dtypes.int32, name='x_hold')
|
||||
y = array_ops.placeholder(dtypes.int32, name='y_hold')
|
||||
z = control_flow_ops.cond(p, lambda: x, lambda: y)
|
||||
array_ops.identity(z, name='result')
|
||||
|
||||
|
||||
def tfgather(_):
|
||||
params = array_ops.placeholder(dtypes.float32, name='params')
|
||||
indices = array_ops.placeholder(dtypes.int32, name='indices')
|
||||
@ -126,14 +142,6 @@ def tfsplits(_):
|
||||
array_ops.identity(y, name='result')
|
||||
|
||||
|
||||
def tfassert_eq(_):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x_hold')
|
||||
y = array_ops.placeholder(dtypes.int32, name='y_hold')
|
||||
control_flow_ops.Assert(
|
||||
math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
|
||||
math_ops.add(x, math_ops.negative(y), name='x_y_diff')
|
||||
|
||||
|
||||
def write_graph(build_graph, out_dir):
|
||||
"""Build a graph using build_graph and write it out."""
|
||||
g = ops.Graph()
|
||||
@ -148,12 +156,13 @@ def main(_):
|
||||
write_graph(tfadd, FLAGS.out_dir)
|
||||
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
||||
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
||||
write_graph(tfassert_eq, FLAGS.out_dir)
|
||||
write_graph(tfcond, FLAGS.out_dir)
|
||||
write_graph(tffunction, FLAGS.out_dir)
|
||||
write_graph(tfgather, FLAGS.out_dir)
|
||||
write_graph(tfmatmul, FLAGS.out_dir)
|
||||
write_graph(tfmatmulandadd, FLAGS.out_dir)
|
||||
write_graph(tffunction, FLAGS.out_dir)
|
||||
write_graph(tfsplits, FLAGS.out_dir)
|
||||
write_graph(tfassert_eq, FLAGS.out_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
20
tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
Normal file
20
tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
Normal file
@ -0,0 +1,20 @@
|
||||
# Text form of tensorflow.tf2xla.Config proto.
|
||||
feed {
|
||||
id { node_name: "p_hold" }
|
||||
shape {}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "x_hold" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y_hold" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "result" }
|
||||
}
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
|
||||
@ -150,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) {
|
||||
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, Cond) {
|
||||
CondComp cond;
|
||||
EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
|
||||
EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
|
||||
EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
|
||||
cond.arg1() = 10;
|
||||
cond.arg2() = 20;
|
||||
{
|
||||
cond.arg0() = true;
|
||||
const int32 expected_result = cond.arg1();
|
||||
EXPECT_TRUE(cond.Run());
|
||||
EXPECT_EQ(cond.result0(), expected_result);
|
||||
EXPECT_EQ(cond.result0_data()[0], expected_result);
|
||||
EXPECT_EQ(cond.result0_data(), cond.results()[0]);
|
||||
}
|
||||
{
|
||||
cond.arg0() = false;
|
||||
const int32 expected_result = cond.arg2();
|
||||
EXPECT_TRUE(cond.Run());
|
||||
EXPECT_EQ(cond.result0(), expected_result);
|
||||
EXPECT_EQ(cond.result0_data()[0], expected_result);
|
||||
EXPECT_EQ(cond.result0_data(), cond.results()[0]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, Gather) {
|
||||
GatherComp gather;
|
||||
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
|
||||
|
Loading…
Reference in New Issue
Block a user