Make function tests a little less brittle.

Function tests depended on a global counter deep in constant folding
logic being just the right value. The value depends on exactly what
other tests involving constant folding run in the same process.

This fix adds a function that constant folds a small graph and
parses the counter value from resulting graph node name. The tests
can they base their expectations off this value. This is still not ideal
as multiple tests can theoretically run in parallel, but better than
before.

PiperOrigin-RevId: 218898255
This commit is contained in:
Igor Ganichev 2018-10-26 12:46:32 -07:00 committed by TensorFlower Gardener
parent 68881e62a8
commit d13e9307d2
2 changed files with 43 additions and 4 deletions

View File

@ -4368,6 +4368,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:function_ops",
@ -4375,6 +4376,7 @@ tf_cc_test(
"//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:shape_ops",
"//third_party/eigen3",
"@com_google_absl//absl/strings",
],
)

View File

@ -18,10 +18,14 @@ limitations under the License.
#include <atomic>
#include <utility>
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
@ -888,18 +892,51 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
EXPECT_EQ(expected_node_names, executed_node_names);
}
// Constant folding generates names using a global counter.
// This function invokes constant folding and parses the counter
// from the generated node name.
int GetConstantFoldingCounter() {
Graph g(OpRegistry::Global());
Scope s = Scope::NewRootScope();
auto a = ops::Const<float>(s, {1.0}, {});
auto b = ops::Const<float>(s, {2.0}, {});
auto add = ops::Add(s.WithOpName("add"), a, b);
auto send =
ops::_Send(s.WithOpName("s1"), add, "add", "sender", 0, "receiver");
TF_CHECK_OK(s.ToGraph(&g));
bool was_mutated;
ConstantFoldingOptions opt{};
TF_CHECK_OK(
ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated));
GraphDef def;
g.ToGraphDef(&def);
for (const NodeDef& node : def.node()) {
if (absl::StartsWith(node.name(), "add/")) {
std::vector<std::string> v = absl::StrSplit(node.name(), "__cf__");
CHECK_GT(v.size(), 1);
int counter;
CHECK(absl::SimpleAtoi(v[v.size() - 1], &counter));
return counter;
}
}
LOG(FATAL) << "Should have found a node that replcaed add";
}
TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}});
ASSERT_TRUE(g != nullptr);
ExpandInlineFunctions(flr0_, g.get());
int cf_counter = GetConstantFoldingCounter();
OptimizeGraph(flr0_, &g);
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
s.WithOpName("x4/x2/scale/_12__cf__13")
s.WithOpName("x4/x2/scale/_12__cf__" + std::to_string(cf_counter + 1))
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@ -1099,20 +1136,20 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
TF_EXPECT_GRAPH_EQ(expected, actual);
}
int cf_counter = GetConstantFoldingCounter();
OptimizeGraph(flr0_, &g);
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
s.WithOpName("scale/_6__cf__18")
s.WithOpName("scale/_6__cf__" + std::to_string(cf_counter + 2))
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
s.WithOpName("Func/_1/sy/_5__cf__17")
s.WithOpName("Func/_1/sy/_5__cf__" + std::to_string(cf_counter + 1))
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(