Add a new unittest in mark_for_compilation_pass_test.

The new test tests that ClusterScopingPass works and MarkForCompilationPass
accordinly preserves the required clustering scopes.
This commit is contained in:
Trent Lo 2019-08-16 13:37:19 -07:00
parent 230cbd8568
commit 50429bdff1
3 changed files with 107 additions and 3 deletions

View File

@ -1718,5 +1718,91 @@ TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
EXPECT_EQ(0, clusters.size()); EXPECT_EQ(0, clusters.size());
} }
namespace {
Node* MakeStageNode(GraphDefBuilder& builder, string name,
std::initializer_list<DataType> dtypes,
gtl::ArraySlice<ops::NodeOut> values) {
auto opts = builder.opts()
.WithName(std::move(name))
.WithAttr("dtypes", std::move(dtypes));
if (opts.HaveError()) {
return nullptr;
}
NodeBuilder node_builder(name, "Stage", opts.op_registry());
node_builder.Input(values);
return opts.FinalizeBuilder(&node_builder);
}
} // namespace
TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
// Construct a graph as below with two pipeline stages and test that nodes
// in different stages will not be merged if ClusterScopingPass is on.
//
// b
// |
// v
// a -> add0 -> relu0 -> stage
//
// b
// |
// v
// unstage -> add1 -> relu1
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("a")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b = ops::SourceOp("Const", builder.opts()
.WithName("b")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* unstage = ops::SourceOp(
"Unstage",
builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));
Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
Node* add1 =
ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});
return GraphDefBuilderToGraph(builder, graph->get());
};
// All nodes go into the same cluster if ClusterScopingPass is off.
{
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(build_staged_graph(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
&graph,
MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_EQ(clusters["add0"], clusters["add1"]);
EXPECT_EQ(clusters["add0"], clusters["relu1"]);
EXPECT_EQ(clusters["relu0"], clusters["add1"]);
EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
}
// By default, ClusterScopingPass is on and different pipeline stages should
// not be merged.
{
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(build_staged_graph(&graph));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["add0"], clusters["add1"]);
EXPECT_NE(clusters["add0"], clusters["relu1"]);
EXPECT_NE(clusters["relu0"], clusters["add1"]);
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
}
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
@ -48,8 +50,14 @@ namespace tensorflow {
opt_options.graph = graph; opt_options.graph = graph;
opt_options.session_options = &session_options; opt_options.session_options = &session_options;
opt_options.flib_def = flib_def; opt_options.flib_def = flib_def;
MarkForCompilationPass pass;
return pass.RunForTest( if (options.enable_cluster_scoping) {
ClusterScopingPass cluster_scoping_pass;
TF_RETURN_IF_ERROR(cluster_scoping_pass.Run(opt_options));
}
MarkForCompilationPass mark_for_compilation_pass;
return mark_for_compilation_pass.RunForTest(
opt_options, opt_options,
/*disable_deadness_analysis=*/options.disable_deadness_analysis); /*disable_deadness_analysis=*/options.disable_deadness_analysis);
} }

View File

@ -24,8 +24,12 @@ class MarkForCompilationPassTestHelper {
struct Options { struct Options {
bool enable_global_jit; bool enable_global_jit;
bool disable_deadness_analysis; bool disable_deadness_analysis;
bool enable_cluster_scoping;
Options() : enable_global_jit(true), disable_deadness_analysis(true) {} Options()
: enable_global_jit(true),
disable_deadness_analysis(true),
enable_cluster_scoping(true) {}
Options WithNoGlobalJit() { Options WithNoGlobalJit() {
Options copy = *this; Options copy = *this;
@ -38,6 +42,12 @@ class MarkForCompilationPassTestHelper {
copy.disable_deadness_analysis = false; copy.disable_deadness_analysis = false;
return copy; return copy;
} }
Options WithNoClusterScoping() {
Options copy = *this;
copy.enable_cluster_scoping = false;
return copy;
}
}; };
// Runs the MarkForCompilation pass on `graph` after assigning all nodes in // Runs the MarkForCompilation pass on `graph` after assigning all nodes in