Add a new unittest in mark_for_compilation_pass_test.
The new test tests that ClusterScopingPass works and MarkForCompilationPass accordinly preserves the required clustering scopes.
This commit is contained in:
parent
230cbd8568
commit
50429bdff1
@ -1718,5 +1718,91 @@ TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
|
|||||||
EXPECT_EQ(0, clusters.size());
|
EXPECT_EQ(0, clusters.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
Node* MakeStageNode(GraphDefBuilder& builder, string name,
|
||||||
|
std::initializer_list<DataType> dtypes,
|
||||||
|
gtl::ArraySlice<ops::NodeOut> values) {
|
||||||
|
auto opts = builder.opts()
|
||||||
|
.WithName(std::move(name))
|
||||||
|
.WithAttr("dtypes", std::move(dtypes));
|
||||||
|
if (opts.HaveError()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
NodeBuilder node_builder(name, "Stage", opts.op_registry());
|
||||||
|
node_builder.Input(values);
|
||||||
|
return opts.FinalizeBuilder(&node_builder);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
|
||||||
|
auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
|
||||||
|
// Construct a graph as below with two pipeline stages and test that nodes
|
||||||
|
// in different stages will not be merged if ClusterScopingPass is on.
|
||||||
|
//
|
||||||
|
// b
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// a -> add0 -> relu0 -> stage
|
||||||
|
//
|
||||||
|
// b
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// unstage -> add1 -> relu1
|
||||||
|
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||||
|
Node* a = ops::SourceOp("Const", builder.opts()
|
||||||
|
.WithName("a")
|
||||||
|
.WithAttr("dtype", DT_FLOAT)
|
||||||
|
.WithAttr("value", Tensor()));
|
||||||
|
Node* b = ops::SourceOp("Const", builder.opts()
|
||||||
|
.WithName("b")
|
||||||
|
.WithAttr("dtype", DT_FLOAT)
|
||||||
|
.WithAttr("value", Tensor()));
|
||||||
|
Node* unstage = ops::SourceOp(
|
||||||
|
"Unstage",
|
||||||
|
builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));
|
||||||
|
|
||||||
|
Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
|
||||||
|
Node* add1 =
|
||||||
|
ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
|
||||||
|
Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
|
||||||
|
ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
|
||||||
|
MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});
|
||||||
|
|
||||||
|
return GraphDefBuilderToGraph(builder, graph->get());
|
||||||
|
};
|
||||||
|
|
||||||
|
// All nodes go into the same cluster if ClusterScopingPass is off.
|
||||||
|
{
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(build_staged_graph(&graph));
|
||||||
|
|
||||||
|
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||||
|
&graph,
|
||||||
|
MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));
|
||||||
|
|
||||||
|
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||||
|
EXPECT_EQ(clusters["add0"], clusters["add1"]);
|
||||||
|
EXPECT_EQ(clusters["add0"], clusters["relu1"]);
|
||||||
|
EXPECT_EQ(clusters["relu0"], clusters["add1"]);
|
||||||
|
EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// By default, ClusterScopingPass is on and different pipeline stages should
|
||||||
|
// not be merged.
|
||||||
|
{
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(build_staged_graph(&graph));
|
||||||
|
|
||||||
|
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||||
|
|
||||||
|
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||||
|
EXPECT_NE(clusters["add0"], clusters["add1"]);
|
||||||
|
EXPECT_NE(clusters["add0"], clusters["relu1"]);
|
||||||
|
EXPECT_NE(clusters["relu0"], clusters["add1"]);
|
||||||
|
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
|
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/public/session_options.h"
|
#include "tensorflow/core/public/session_options.h"
|
||||||
@ -48,8 +50,14 @@ namespace tensorflow {
|
|||||||
opt_options.graph = graph;
|
opt_options.graph = graph;
|
||||||
opt_options.session_options = &session_options;
|
opt_options.session_options = &session_options;
|
||||||
opt_options.flib_def = flib_def;
|
opt_options.flib_def = flib_def;
|
||||||
MarkForCompilationPass pass;
|
|
||||||
return pass.RunForTest(
|
if (options.enable_cluster_scoping) {
|
||||||
|
ClusterScopingPass cluster_scoping_pass;
|
||||||
|
TF_RETURN_IF_ERROR(cluster_scoping_pass.Run(opt_options));
|
||||||
|
}
|
||||||
|
|
||||||
|
MarkForCompilationPass mark_for_compilation_pass;
|
||||||
|
return mark_for_compilation_pass.RunForTest(
|
||||||
opt_options,
|
opt_options,
|
||||||
/*disable_deadness_analysis=*/options.disable_deadness_analysis);
|
/*disable_deadness_analysis=*/options.disable_deadness_analysis);
|
||||||
}
|
}
|
||||||
|
@ -24,8 +24,12 @@ class MarkForCompilationPassTestHelper {
|
|||||||
struct Options {
|
struct Options {
|
||||||
bool enable_global_jit;
|
bool enable_global_jit;
|
||||||
bool disable_deadness_analysis;
|
bool disable_deadness_analysis;
|
||||||
|
bool enable_cluster_scoping;
|
||||||
|
|
||||||
Options() : enable_global_jit(true), disable_deadness_analysis(true) {}
|
Options()
|
||||||
|
: enable_global_jit(true),
|
||||||
|
disable_deadness_analysis(true),
|
||||||
|
enable_cluster_scoping(true) {}
|
||||||
|
|
||||||
Options WithNoGlobalJit() {
|
Options WithNoGlobalJit() {
|
||||||
Options copy = *this;
|
Options copy = *this;
|
||||||
@ -38,6 +42,12 @@ class MarkForCompilationPassTestHelper {
|
|||||||
copy.disable_deadness_analysis = false;
|
copy.disable_deadness_analysis = false;
|
||||||
return copy;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Options WithNoClusterScoping() {
|
||||||
|
Options copy = *this;
|
||||||
|
copy.enable_cluster_scoping = false;
|
||||||
|
return copy;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Runs the MarkForCompilation pass on `graph` after assigning all nodes in
|
// Runs the MarkForCompilation pass on `graph` after assigning all nodes in
|
||||||
|
Loading…
Reference in New Issue
Block a user