Add test with tf.cond.

PiperOrigin-RevId: 195745718
This commit is contained in:
Jacques Pienaar 2018-05-07 16:59:41 -07:00 committed by TensorFlower Gardener
parent 3964bdeef8
commit db63348bf1
4 changed files with 79 additions and 10 deletions

View File

@ -15,6 +15,7 @@ test_suite(
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
":test_graph_tfassert_eq_test",
":test_graph_tfcond_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
@ -55,6 +56,7 @@ genrule(
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
"test_graph_tfassert_eq.pb",
"test_graph_tfcond.pb",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
@ -118,6 +120,17 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfcond",
testonly = 1,
config = "test_graph_tfcond.config.pbtxt",
cpp_class = "CondComp",
graph = "test_graph_tfcond.pb",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tffunction",
testonly = 1,
@ -194,6 +207,7 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tfassert_eq",
":test_graph_tfcond",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",

View File

@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir):
f.write(saver.as_saver_def().SerializeToString())
def tfassert_eq(_):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
control_flow_ops.Assert(
math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
math_ops.add(x, math_ops.negative(y), name='x_y_diff')
def tfcond(_):
p = array_ops.placeholder(dtypes.bool, name='p_hold')
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
z = control_flow_ops.cond(p, lambda: x, lambda: y)
array_ops.identity(z, name='result')
def tfgather(_):
params = array_ops.placeholder(dtypes.float32, name='params')
indices = array_ops.placeholder(dtypes.int32, name='indices')
@ -126,14 +142,6 @@ def tfsplits(_):
array_ops.identity(y, name='result')
def tfassert_eq(_):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
control_flow_ops.Assert(
math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
math_ops.add(x, math_ops.negative(y), name='x_y_diff')
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@ -148,12 +156,13 @@ def main(_):
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)
write_graph(tfcond, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)
if __name__ == '__main__':

View File

@ -0,0 +1,20 @@
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "p_hold" }
shape {}
}
feed {
id { node_name: "x_hold" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_hold" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "result" }
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
@ -150,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
TEST(TFCompileTest, Cond) {
CondComp cond;
EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
cond.arg1() = 10;
cond.arg2() = 20;
{
cond.arg0() = true;
const int32 expected_result = cond.arg1();
EXPECT_TRUE(cond.Run());
EXPECT_EQ(cond.result0(), expected_result);
EXPECT_EQ(cond.result0_data()[0], expected_result);
EXPECT_EQ(cond.result0_data(), cond.results()[0]);
}
{
cond.arg0() = false;
const int32 expected_result = cond.arg2();
EXPECT_TRUE(cond.Run());
EXPECT_EQ(cond.result0(), expected_result);
EXPECT_EQ(cond.result0_data()[0], expected_result);
EXPECT_EQ(cond.result0_data(), cond.results()[0]);
}
}
TEST(TFCompileTest, Gather) {
GatherComp gather;
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);