Make function tests a little less brittle.
Function tests depended on a global counter deep in constant folding logic being just the right value. The value depends on exactly what other tests involving constant folding run in the same process. This fix adds a function that constant folds a small graph and parses the counter value from resulting graph node name. The tests can they base their expectations off this value. This is still not ideal as multiple tests can theoretically run in parallel, but better than before. PiperOrigin-RevId: 218898255
This commit is contained in:
parent
68881e62a8
commit
d13e9307d2
@ -4368,6 +4368,7 @@ tf_cc_test(
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/core/kernels:cast_op",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
@ -4375,6 +4376,7 @@ tf_cc_test(
|
||||
"//tensorflow/core/kernels:random_ops",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,10 +18,14 @@ limitations under the License.
|
||||
#include <atomic>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/cc/ops/array_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
@ -888,18 +892,51 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
|
||||
EXPECT_EQ(expected_node_names, executed_node_names);
|
||||
}
|
||||
|
||||
// Constant folding generates names using a global counter.
|
||||
// This function invokes constant folding and parses the counter
|
||||
// from the generated node name.
|
||||
int GetConstantFoldingCounter() {
|
||||
Graph g(OpRegistry::Global());
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto a = ops::Const<float>(s, {1.0}, {});
|
||||
auto b = ops::Const<float>(s, {2.0}, {});
|
||||
|
||||
auto add = ops::Add(s.WithOpName("add"), a, b);
|
||||
auto send =
|
||||
ops::_Send(s.WithOpName("s1"), add, "add", "sender", 0, "receiver");
|
||||
|
||||
TF_CHECK_OK(s.ToGraph(&g));
|
||||
bool was_mutated;
|
||||
ConstantFoldingOptions opt{};
|
||||
TF_CHECK_OK(
|
||||
ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated));
|
||||
GraphDef def;
|
||||
g.ToGraphDef(&def);
|
||||
for (const NodeDef& node : def.node()) {
|
||||
if (absl::StartsWith(node.name(), "add/")) {
|
||||
std::vector<std::string> v = absl::StrSplit(node.name(), "__cf__");
|
||||
CHECK_GT(v.size(), 1);
|
||||
int counter;
|
||||
CHECK(absl::SimpleAtoi(v[v.size() - 1], &counter));
|
||||
return counter;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Should have found a node that replcaed add";
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
|
||||
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
|
||||
test::function::XTimes16()});
|
||||
std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}});
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
ExpandInlineFunctions(flr0_, g.get());
|
||||
int cf_counter = GetConstantFoldingCounter();
|
||||
OptimizeGraph(flr0_, &g);
|
||||
{
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
|
||||
auto x4_x2_scale = ops::Const<float>(
|
||||
s.WithOpName("x4/x2/scale/_12__cf__13")
|
||||
s.WithOpName("x4/x2/scale/_12__cf__" + std::to_string(cf_counter + 1))
|
||||
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
|
||||
2.0f);
|
||||
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
|
||||
@ -1099,20 +1136,20 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
|
||||
TF_EXPECT_GRAPH_EQ(expected, actual);
|
||||
}
|
||||
|
||||
int cf_counter = GetConstantFoldingCounter();
|
||||
OptimizeGraph(flr0_, &g);
|
||||
|
||||
{
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
|
||||
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
|
||||
auto scale = ops::Const(
|
||||
s.WithOpName("scale/_6__cf__18")
|
||||
s.WithOpName("scale/_6__cf__" + std::to_string(cf_counter + 2))
|
||||
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
|
||||
2.0f);
|
||||
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
|
||||
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
|
||||
auto const0 = ops::Const(
|
||||
s.WithOpName("Func/_1/sy/_5__cf__17")
|
||||
s.WithOpName("Func/_1/sy/_5__cf__" + std::to_string(cf_counter + 1))
|
||||
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
|
||||
0, {0});
|
||||
auto func1_rx = ops::internal::BroadcastGradientArgs(
|
||||
|
Loading…
x
Reference in New Issue
Block a user