Change heuristics around creating small clusters

- Don't count "Const" nodes in cluster size
- Bump up the default minimum cluster size to 4

These changes are motivated by some internal auto-clustering benchmarks.

Some tests had to be updated to work with these new heuristics.  I've tried to
ensure that the tests keep testing what they intended to test.

However, these changes do expose a weakness in our tests: tests that are not
specifically about our clustering heuristics should be more resilient to
changes in those heuristics.

PiperOrigin-RevId: 238344897
Author: Sanjoy Das, 2019-03-13 17:26:33 -07:00 (committed by TensorFlower Gardener)
Commit: c94aab2f47 (parent: 825f04f9f2)
11 changed files with 51 additions and 78 deletions
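
In rough terms, the updated size check reduces to the sketch below. This is an
illustration only: the function and constant names are made up for this note,
and the authoritative logic is the C++ change to mark_for_compilation_pass.cc
further down in this diff.

MIN_CLUSTER_SIZE = 4  # new default for --tf_xla_min_cluster_size


def should_compile_cluster(node_op_types, marked_for_compilation=False):
  # Identity and Const nodes do not enable interesting optimizations by
  # themselves, so they no longer count towards the effective cluster size.
  effective_size = sum(
      1 for op in node_op_types if op not in ("Identity", "Const"))
  # Functional control flow is assumed to hide at least MIN_CLUSTER_SIZE
  # non-trivial nodes.
  has_functional_control_flow = any(
      op in ("While", "If") for op in node_op_types)
  return (effective_size >= MIN_CLUSTER_SIZE or has_functional_control_flow or
          marked_for_compilation)


# A four-node cluster that is mostly constants is now considered too small,
# while four non-trivial ops still clear the threshold.
assert not should_compile_cluster(["Const", "Const", "MatMul", "MatMul"])
assert should_compile_cluster(["MatMul", "Add", "Relu", "Sub"])

Counting only the "interesting" nodes keeps trivially small clusters from
paying the Switch/Merge and compilation overhead mentioned in the code comment
below.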


@@ -630,7 +630,6 @@ tf_cc_test(
srcs = [
"build_xla_ops_pass_test.cc",
"clone_constants_for_better_clustering_test.cc",
"compilation_passes_test_main.cc",
"encapsulate_subgraphs_pass_test.cc",
"encapsulate_xla_computations_pass_test.cc",
"extract_outside_compilation_pass_test.cc",


@@ -38,10 +38,13 @@ GTEST_API_ int main(int real_argc, char** real_argv) {
     void operator()(char* ptr) { free(ptr); }
   };
-  std::unique_ptr<char, FreeDeleter> allocated_arg(
+  std::unique_ptr<char, FreeDeleter> enable_global_jit_arg(
       strdup("--tf_xla_cpu_global_jit=true"));
-  args.push_back(allocated_arg.get());
+  args.push_back(enable_global_jit_arg.get());
+  std::unique_ptr<char, FreeDeleter> reduce_min_cluster_size_arg(
+      strdup("--tf_xla_min_cluster_size=2"));
+  args.push_back(reduce_min_cluster_size_arg.get());
 
   int argc = args.size();


@@ -73,7 +73,7 @@ void AllocateAndParseFlags() {
   mark_for_compilation_flags = new MarkForCompilationPassFlags;
   mark_for_compilation_flags->tf_xla_auto_jit = 0;
-  mark_for_compilation_flags->tf_xla_min_cluster_size = 2;
+  mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
   mark_for_compilation_flags->tf_xla_max_cluster_size =
       std::numeric_limits<int32>::max();
   mark_for_compilation_flags->tf_xla_clustering_debug = false;


@@ -1263,13 +1263,24 @@ Status MarkForCompilationPass::RunImpl(
   // Count the number of non-trivial elements in each cluster.
   std::vector<int> effective_cluster_sizes(graph->num_node_ids());
+
+  // has_functional_control_flow remembers if a cluster contains a functional
+  // control flow node.
+  std::vector<bool> has_functional_control_flow(graph->num_node_ids());
+
   for (const Node* n : compilation_candidates) {
     int cluster = clusters[n->id()].Get().representative;
-    // Identity nodes will be removed if the node gets marked for compilation.
-    // Therefore we don't want to count them towards the effective cluster size.
-    if (n->def().op() != "Identity") {
+    // We want clusters to be big enough that the benefit from XLA's
+    // optimizations offsets XLA related overhead (for instance we add some
+    // Switch/Merge nodes into the graph to implement lazy compilation).  To
+    // this end, we don't count Identity and Constant nodes because they do not
+    // enable interesting optimizations by themselves.
+    if (!n->IsIdentity() && !n->IsConstant()) {
       effective_cluster_sizes[cluster]++;
     }
+
+    if (n->type_string() == "While" || n->type_string() == "If") {
+      has_functional_control_flow[cluster] = true;
+    }
   }
 
   // Names for each cluster.
@@ -1312,11 +1323,13 @@ Status MarkForCompilationPass::RunImpl(
       marked_for_compilation = compile_attr;
     }
 
-    // Compile if this is a cluster of >= min_cluster_size compilable operators.
-    // Also, always compile if it contains at least one op that is marked for
-    // compilation that is not an Identity op.
+    // We assume that functional If and While nodes have at least
+    // min_cluster_size non-trivial nodes in them.  It would be more principled
+    // to (recursively) verify this fact, but that's probably not worth the
+    // trouble.
     if (effective_cluster_sizes[cluster_repr] >= min_cluster_size ||
-        (effective_cluster_sizes[cluster_repr] > 0 && marked_for_compilation)) {
+        has_functional_control_flow[cluster_repr] || marked_for_compilation) {
       string& name = cluster_names[cluster_repr];
       if (name.empty()) {


@@ -195,35 +195,6 @@ TEST(XlaCompilationTest, HalfSupported) {
   EXPECT_FALSE(clusters.empty());
 }
 
-TEST(XlaCompilationTest, ConcatWithConstArg) {
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  GraphDef graphdef;
-  {
-    Tensor t(DT_INT32, TensorShape());
-    t.scalar<int32>()() = 0;
-    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
-    Node* dim = ops::SourceOp("Const", builder.opts()
-                                           .WithName("Dim")
-                                           .WithAttr("dtype", DT_INT32)
-                                           .WithAttr("value", t));
-    Node* a = ops::SourceOp("Const", builder.opts()
-                                         .WithName("A")
-                                         .WithAttr("dtype", DT_FLOAT)
-                                         .WithAttr("value", t));
-
-    NodeBuilder concat_builder("Concat", "Concat",
-                               builder.opts().op_registry());
-    concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
-    builder.opts().FinalizeBuilder(&concat_builder);
-
-    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
-  }
-
-  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
-  auto clusters = GetClusters(*graph);
-
-  EXPECT_EQ(3, clusters.size());  // Everything should be compiled.
-}
 
 TEST(XlaCompilationTest, FunctionCalls) {
   FunctionDef compilable = FunctionDefHelper::Define(
       "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
@@ -606,10 +577,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
   absl::flat_hash_map<string, std::vector<string>> cluster_sets =
       GetClusterSets(*graph);
-  ASSERT_EQ(cluster_sets.size(), 1);
-  std::vector<string> expected_clustered_nodes = {"AssignmentW",
-                                                  "ValueToAssignW"};
-  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
+  ASSERT_EQ(cluster_sets.size(), 0);
 }
 
 TEST(XlaCompilationTest, ChainOfOps) {
@@ -637,15 +605,11 @@ TEST(XlaCompilationTest, ChainOfOps) {
   absl::flat_hash_map<string, std::vector<string>> cluster_sets =
       GetClusterSets(*graph, &cluster_names);
-  ASSERT_EQ(cluster_sets.size(), 2);
+  ASSERT_EQ(cluster_sets.size(), 1);
 
-  std::vector<string> expected_clustered_nodes_a = {"AssignmentW0", "ConstN0",
-                                                    "ValueToAssignW0"};
-  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
-
-  std::vector<string> expected_clustered_nodes_b = {
+  std::vector<string> expected_clustered_nodes_a = {
       "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
-  ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b);
+  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
 }
 
 TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
@@ -704,9 +668,7 @@ TEST(XlaCompilationTest, Retval) {
   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
   auto clusters = GetClusters(*graph);
 
-  EXPECT_EQ(2, clusters.size());
-  EXPECT_TRUE(clusters.find("R") == clusters.cend());
-  EXPECT_EQ(clusters["A"], clusters["B"]);
+  EXPECT_TRUE(clusters.empty());
 }
 
 TEST(XlaCompilationTest, DontCountIdentityOps) {
@@ -725,22 +687,6 @@ TEST(XlaCompilationTest, DontCountIdentityOps) {
   EXPECT_TRUE(clusters.empty());
 }
 
-TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  Scope root = Scope::NewRootScope().ExitOnError();
-  {
-    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
-    auto b = ops::Identity(root.WithOpName("B"), a);
-    b.node()->AddAttr(kXlaCompileAttr, true);
-    auto r = ops::_Retval(root.WithOpName("R"), b, 0);
-  }
-
-  TF_ASSERT_OK(root.ToGraph(graph.get()));
-  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
-  auto clusters = GetClusters(*graph);
-
-  EXPECT_TRUE(clusters.empty());
-}
 
 TEST(XlaCompilationTest, ConstOp) {
   // valid data type
   {
@@ -996,8 +942,10 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) {
   absl::string_view xla_gpu_dev1 =
       "/job:worker/replica:0/task:0/device:XLA_GPU:1";
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
-  Output b = ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2});
+  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
+                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
+  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
+                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
   Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
   Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);


@@ -513,9 +513,10 @@ class ElementWiseFusionTest(test.TestCase):
   def testElementWiseClustering(self):
     arg0 = np.random.rand(2, 2).astype(np.float32)
     arg1 = np.random.rand(2, 2).astype(np.float32)
-    os.environ["TF_XLA_FLAGS"] = (
-        "--tf_xla_fusion_only=true "
-        "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
+    old_tf_xla_flags = os.environ.get("TF_XLA_FLAGS", "")
+    os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
+                                  "--tf_xla_min_cluster_size=2 "
+                                  "--tf_xla_cpu_global_jit " + old_tf_xla_flags)
     tf_op, tf_count = self.simpleTest(arg0, arg1,
                                       config_pb2.OptimizerOptions.OFF)
     self.assertEqual(0, tf_count)
@@ -525,6 +526,7 @@ class ElementWiseFusionTest(test.TestCase):
     self.assertEqual(2, tfef_count)
 
     self.assertAllClose(tf_op, tfef_op, rtol=1e-1)
+    os.environ["TF_XLA_FLAGS"] = old_tf_xla_flags
 
 
 class LazyCompilationTest(test.TestCase):
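
The save/restore of TF_XLA_FLAGS above is what keeps this test exercising the
same clusters it did before the default bump to 4. A minimal standalone sketch
of the same pattern follows; run_clustering_sensitive_test is a placeholder,
only --tf_xla_min_cluster_size=2 is taken from this change, and the variable
has to be in place before the process first parses its XLA flags.

import os


def run_clustering_sensitive_test():
  """Placeholder for a test body that depends on XLA clustering behavior."""
  pass


old_tf_xla_flags = os.environ.get("TF_XLA_FLAGS", "")
os.environ["TF_XLA_FLAGS"] = "--tf_xla_min_cluster_size=2 " + old_tf_xla_flags
try:
  run_clustering_sensitive_test()
finally:
  # Restore the environment so the override does not leak into later tests
  # launched from this process.
  os.environ["TF_XLA_FLAGS"] = old_tf_xla_flags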


@@ -36,7 +36,7 @@ class RandomOpsTest(xla_test.XLATestCase):
   def _random_types(self):
     return set(self.numeric_types) - set(
-        self.complex_types) - {np.uint8, np.int8}
+        self.complex_types) - {np.uint64, np.int64, np.uint8, np.int8}
 
   def _testRngIsNotConstant(self, rng, dtype):
     # Tests that 'rng' does not always return the same value.


@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
@@ -1073,4 +1074,6 @@ class TensorArrayTest(xla_test.XLATestCase):
     self.assertEqual(size1_v, 4)
 
 if __name__ == "__main__":
+  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
+                                os.environ.get("TF_XLA_FLAGS", ""))
   test.main()


@@ -18,6 +18,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import constant_op
@@ -211,4 +212,6 @@ class ListOpsTest(xla_test.XLATestCase):
     self.assertAllEqual(t, [0., 0., 0.])
 
 if __name__ == "__main__":
+  os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' +
+                                os.environ.get('TF_XLA_FLAGS', ''))
   test.main()


@@ -246,4 +246,6 @@ def is_compile_on_demand():
 
 if __name__ == "__main__":
+  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
+                                os.environ.get("TF_XLA_FLAGS", ""))
   test.main()


@@ -341,8 +341,8 @@ class LSTMV2Test(keras_parameterized.TestCase):
     cudnn_model.fit(x_train, y_train)
     y_4 = cudnn_model.predict(x_train)
 
-    self.assertAllClose(y_1, y_3, rtol=1e-5, atol=1e-5)
-    self.assertAllClose(y_2, y_4, rtol=1e-5, atol=1e-5)
+    self.assertAllClose(y_1, y_3, rtol=1e-5, atol=2e-5)
+    self.assertAllClose(y_2, y_4, rtol=1e-5, atol=2e-5)
 
   @parameterized.named_parameters(('v0', 0), ('v1', 1), ('v2', 2))
   def DISABLED_test_implementation_mode_LSTM(self, implementation_mode):