Add a new unittest in mark_for_compilation_pass_test.
The new test tests that ClusterScopingPass works and MarkForCompilationPass accordinly preserves the required clustering scopes.
This commit is contained in:
parent
230cbd8568
commit
50429bdff1
@ -1718,5 +1718,91 @@ TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
|
||||
EXPECT_EQ(0, clusters.size());
|
||||
}
|
||||
|
||||
namespace {
|
||||
Node* MakeStageNode(GraphDefBuilder& builder, string name,
|
||||
std::initializer_list<DataType> dtypes,
|
||||
gtl::ArraySlice<ops::NodeOut> values) {
|
||||
auto opts = builder.opts()
|
||||
.WithName(std::move(name))
|
||||
.WithAttr("dtypes", std::move(dtypes));
|
||||
if (opts.HaveError()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
NodeBuilder node_builder(name, "Stage", opts.op_registry());
|
||||
node_builder.Input(values);
|
||||
return opts.FinalizeBuilder(&node_builder);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
|
||||
auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
|
||||
// Construct a graph as below with two pipeline stages and test that nodes
|
||||
// in different stages will not be merged if ClusterScopingPass is on.
|
||||
//
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// a -> add0 -> relu0 -> stage
|
||||
//
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// unstage -> add1 -> relu1
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("a")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("b")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* unstage = ops::SourceOp(
|
||||
"Unstage",
|
||||
builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));
|
||||
|
||||
Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
|
||||
Node* add1 =
|
||||
ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
|
||||
Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
|
||||
ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
|
||||
MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});
|
||||
|
||||
return GraphDefBuilderToGraph(builder, graph->get());
|
||||
};
|
||||
|
||||
// All nodes go into the same cluster if ClusterScopingPass is off.
|
||||
{
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(build_staged_graph(&graph));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
&graph,
|
||||
MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(clusters["add0"], clusters["add1"]);
|
||||
EXPECT_EQ(clusters["add0"], clusters["relu1"]);
|
||||
EXPECT_EQ(clusters["relu0"], clusters["add1"]);
|
||||
EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
|
||||
}
|
||||
|
||||
// By default, ClusterScopingPass is on and different pipeline stages should
|
||||
// not be merged.
|
||||
{
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(build_staged_graph(&graph));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_NE(clusters["add0"], clusters["add1"]);
|
||||
EXPECT_NE(clusters["add0"], clusters["relu1"]);
|
||||
EXPECT_NE(clusters["relu0"], clusters["add1"]);
|
||||
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
@ -48,8 +50,14 @@ namespace tensorflow {
|
||||
opt_options.graph = graph;
|
||||
opt_options.session_options = &session_options;
|
||||
opt_options.flib_def = flib_def;
|
||||
MarkForCompilationPass pass;
|
||||
return pass.RunForTest(
|
||||
|
||||
if (options.enable_cluster_scoping) {
|
||||
ClusterScopingPass cluster_scoping_pass;
|
||||
TF_RETURN_IF_ERROR(cluster_scoping_pass.Run(opt_options));
|
||||
}
|
||||
|
||||
MarkForCompilationPass mark_for_compilation_pass;
|
||||
return mark_for_compilation_pass.RunForTest(
|
||||
opt_options,
|
||||
/*disable_deadness_analysis=*/options.disable_deadness_analysis);
|
||||
}
|
||||
|
@ -24,8 +24,12 @@ class MarkForCompilationPassTestHelper {
|
||||
struct Options {
|
||||
bool enable_global_jit;
|
||||
bool disable_deadness_analysis;
|
||||
bool enable_cluster_scoping;
|
||||
|
||||
Options() : enable_global_jit(true), disable_deadness_analysis(true) {}
|
||||
Options()
|
||||
: enable_global_jit(true),
|
||||
disable_deadness_analysis(true),
|
||||
enable_cluster_scoping(true) {}
|
||||
|
||||
Options WithNoGlobalJit() {
|
||||
Options copy = *this;
|
||||
@ -38,6 +42,12 @@ class MarkForCompilationPassTestHelper {
|
||||
copy.disable_deadness_analysis = false;
|
||||
return copy;
|
||||
}
|
||||
|
||||
Options WithNoClusterScoping() {
|
||||
Options copy = *this;
|
||||
copy.enable_cluster_scoping = false;
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
|
||||
// Runs the MarkForCompilation pass on `graph` after assigning all nodes in
|
||||
|
Loading…
Reference in New Issue
Block a user