Change heuristics around creating small clusters
- Don't count "Const" nodes in cluster size - Bump up the default minimum cluster size to 4 These changes are motivated by some internal auto-clustering benchmarks. Some tests had to be updated to work with these new heuristics. I've tried to ensure that the tests keep testing what they intended to test. However, these changes do demonstrate a weakness in our tests; the tests not specifically testing our clustering heuristics should be more resilient to changes to our clustering heuristics. PiperOrigin-RevId: 238344897
commit c94aab2f47 (parent 825f04f9f2)
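The two heuristic changes compound: fewer nodes count toward a cluster's size, and the size bar itself is higher. A toy Python sketch (illustrative only; the real logic is the C++ in MarkForCompilationPass below) of how a small cluster fares before and after this change:

# Hypothetical cluster: two constants feeding two arithmetic ops.
cluster_ops = ["Const", "Const", "Add", "Mul"]

def effective_size(ops, uncounted):
    # Ops listed in `uncounted` are treated as trivial and skipped.
    return sum(op not in uncounted for op in ops)

# Old heuristic: only Identity was trivial, and the default minimum was 2.
print(effective_size(cluster_ops, ("Identity",)) >= 2)          # True: clustered
# New heuristic: Const is also trivial, and the default minimum is 4.
print(effective_size(cluster_ops, ("Identity", "Const")) >= 4)  # False: not clustered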
@@ -630,7 +630,6 @@ tf_cc_test(
     srcs = [
         "build_xla_ops_pass_test.cc",
         "clone_constants_for_better_clustering_test.cc",
-        "compilation_passes_test_main.cc",
         "encapsulate_subgraphs_pass_test.cc",
         "encapsulate_xla_computations_pass_test.cc",
         "extract_outside_compilation_pass_test.cc",
@@ -38,10 +38,13 @@ GTEST_API_ int main(int real_argc, char** real_argv) {
     void operator()(char* ptr) { free(ptr); }
   };
 
-  std::unique_ptr<char, FreeDeleter> allocated_arg(
+  std::unique_ptr<char, FreeDeleter> enable_global_jit_arg(
       strdup("--tf_xla_cpu_global_jit=true"));
+  args.push_back(enable_global_jit_arg.get());
 
-  args.push_back(allocated_arg.get());
+  std::unique_ptr<char, FreeDeleter> reduce_min_cluster_size_arg(
+      strdup("--tf_xla_min_cluster_size=2"));
+  args.push_back(reduce_min_cluster_size_arg.get());
 
   int argc = args.size();
 
@@ -73,7 +73,7 @@ void AllocateAndParseFlags() {
 
   mark_for_compilation_flags = new MarkForCompilationPassFlags;
   mark_for_compilation_flags->tf_xla_auto_jit = 0;
-  mark_for_compilation_flags->tf_xla_min_cluster_size = 2;
+  mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
   mark_for_compilation_flags->tf_xla_max_cluster_size =
       std::numeric_limits<int32>::max();
   mark_for_compilation_flags->tf_xla_clustering_debug = false;
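Callers that depend on the old threshold can override the new default of 4 per process through the TF_XLA_FLAGS environment variable, set before TensorFlow parses its XLA flags. The updated Python tests later in this diff use exactly this pattern; as a standalone snippet:

import os

# Prepend the override, keeping any flags the environment already carries.
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
                              os.environ.get("TF_XLA_FLAGS", ""))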
@@ -1263,13 +1263,24 @@ Status MarkForCompilationPass::RunImpl(
 
   // Count the number of non-trivial elements in each cluster.
   std::vector<int> effective_cluster_sizes(graph->num_node_ids());
+
+  // has_functional_control_flow remembers if a cluster contains a functional
+  // control flow node.
+  std::vector<bool> has_functional_control_flow(graph->num_node_ids());
+
   for (const Node* n : compilation_candidates) {
     int cluster = clusters[n->id()].Get().representative;
-    // Identity nodes will be removed if the node gets marked for compilation.
-    // Therefore we don't want to count them towards the effective cluster size.
-    if (n->def().op() != "Identity") {
+    // We want clusters to be big enough that the benefit from XLA's
+    // optimizations offsets XLA related overhead (for instance we add some
+    // Switch/Merge nodes into the graph to implement lazy compilation). To
+    // this end, we don't count Identity and Constant nodes because they do not
+    // enable interesting optimizations by themselves.
+    if (!n->IsIdentity() && !n->IsConstant()) {
       effective_cluster_sizes[cluster]++;
     }
+    if (n->type_string() == "While" || n->type_string() == "If") {
+      has_functional_control_flow[cluster] = true;
+    }
   }
 
   // Names for each cluster.
@@ -1312,11 +1323,13 @@ Status MarkForCompilationPass::RunImpl(
       marked_for_compilation = compile_attr;
     }
 
-    // Compile if this is a cluster of >= min_cluster_size compilable operators.
-    // Also, always compile if it contains at least one op that is marked for
-    // compilation that is not an Identity op.
+    // We assume that functional If and While nodes have at least
+    // min_cluster_size non-trivial nodes in them. It would be more principled
+    // to (recursively) verify this fact, but that's probably not worth the
+    // trouble.
+
     if (effective_cluster_sizes[cluster_repr] >= min_cluster_size ||
-        (effective_cluster_sizes[cluster_repr] > 0 && marked_for_compilation)) {
+        has_functional_control_flow[cluster_repr] || marked_for_compilation) {
       string& name = cluster_names[cluster_repr];
 
       if (name.empty()) {
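Taken together, the two hunks above reduce the compile decision to a three-way disjunction. A simplified Python rendering (assuming the per-cluster inputs were computed as in the counting loop above; not the actual C++ API):

def should_compile(effective_size, has_functional_control_flow,
                   marked_for_compilation, min_cluster_size=4):
    # Functional If/While are assumed to contain at least min_cluster_size
    # non-trivial nodes, so they make a cluster compilable outright. Note that
    # a marked cluster no longer also needs a non-zero effective size.
    return (effective_size >= min_cluster_size or
            has_functional_control_flow or
            marked_for_compilation)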
@@ -195,35 +195,6 @@ TEST(XlaCompilationTest, HalfSupported) {
   EXPECT_FALSE(clusters.empty());
 }
 
-TEST(XlaCompilationTest, ConcatWithConstArg) {
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  GraphDef graphdef;
-  {
-    Tensor t(DT_INT32, TensorShape());
-    t.scalar<int32>()() = 0;
-    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
-    Node* dim = ops::SourceOp("Const", builder.opts()
-                                           .WithName("Dim")
-                                           .WithAttr("dtype", DT_INT32)
-                                           .WithAttr("value", t));
-    Node* a = ops::SourceOp("Const", builder.opts()
-                                         .WithName("A")
-                                         .WithAttr("dtype", DT_FLOAT)
-                                         .WithAttr("value", t));
-
-    NodeBuilder concat_builder("Concat", "Concat",
-                               builder.opts().op_registry());
-    concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
-    builder.opts().FinalizeBuilder(&concat_builder);
-
-    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
-  }
-
-  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
-  auto clusters = GetClusters(*graph);
-  EXPECT_EQ(3, clusters.size());  // Everything should be compiled.
-}
-
 TEST(XlaCompilationTest, FunctionCalls) {
   FunctionDef compilable = FunctionDefHelper::Define(
       "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
@@ -606,10 +577,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
   absl::flat_hash_map<string, std::vector<string>> cluster_sets =
       GetClusterSets(*graph);
-  ASSERT_EQ(cluster_sets.size(), 1);
-  std::vector<string> expected_clustered_nodes = {"AssignmentW",
-                                                  "ValueToAssignW"};
-  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
+  ASSERT_EQ(cluster_sets.size(), 0);
 }
 
 TEST(XlaCompilationTest, ChainOfOps) {
@@ -637,15 +605,11 @@ TEST(XlaCompilationTest, ChainOfOps) {
   absl::flat_hash_map<string, std::vector<string>> cluster_sets =
       GetClusterSets(*graph, &cluster_names);
 
-  ASSERT_EQ(cluster_sets.size(), 2);
+  ASSERT_EQ(cluster_sets.size(), 1);
 
-  std::vector<string> expected_clustered_nodes_a = {"AssignmentW0", "ConstN0",
-                                                    "ValueToAssignW0"};
-  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
-
-  std::vector<string> expected_clustered_nodes_b = {
+  std::vector<string> expected_clustered_nodes_a = {
       "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
-  ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b);
+  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
 }
 
 TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
@@ -704,9 +668,7 @@ TEST(XlaCompilationTest, Retval) {
   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
   auto clusters = GetClusters(*graph);
 
-  EXPECT_EQ(2, clusters.size());
-  EXPECT_TRUE(clusters.find("R") == clusters.cend());
-  EXPECT_EQ(clusters["A"], clusters["B"]);
+  EXPECT_TRUE(clusters.empty());
 }
 
 TEST(XlaCompilationTest, DontCountIdentityOps) {
@@ -725,22 +687,6 @@ TEST(XlaCompilationTest, DontCountIdentityOps) {
   EXPECT_TRUE(clusters.empty());
 }
 
-TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  Scope root = Scope::NewRootScope().ExitOnError();
-  {
-    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
-    auto b = ops::Identity(root.WithOpName("B"), a);
-    b.node()->AddAttr(kXlaCompileAttr, true);
-    auto r = ops::_Retval(root.WithOpName("R"), b, 0);
-  }
-  TF_ASSERT_OK(root.ToGraph(graph.get()));
-  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
-  auto clusters = GetClusters(*graph);
-
-  EXPECT_TRUE(clusters.empty());
-}
-
 TEST(XlaCompilationTest, ConstOp) {
   // valid data type
   {
@@ -996,8 +942,10 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) {
   absl::string_view xla_gpu_dev1 =
       "/job:worker/replica:0/task:0/device:XLA_GPU:1";
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
-  Output b = ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2});
+  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
+                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
+  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
+                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
   Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
   Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);
 
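Wrapping each Const in a Tanh presumably keeps every per-device cluster populated with ops that still count under the new rule; a toy check of that reasoning (op lists are illustrative, not the test's actual graph representation):

def effective_size(ops):
    return sum(op not in ("Identity", "Const") for op in ops)

print(effective_size(["Const", "MatMul"]))          # 1: the Const adds nothing
print(effective_size(["Const", "Tanh", "MatMul"]))  # 2: Tanh keeps the cluster countable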
@@ -513,9 +513,10 @@ class ElementWiseFusionTest(test.TestCase):
   def testElementWiseClustering(self):
     arg0 = np.random.rand(2, 2).astype(np.float32)
     arg1 = np.random.rand(2, 2).astype(np.float32)
-    os.environ["TF_XLA_FLAGS"] = (
-        "--tf_xla_fusion_only=true "
-        "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
+    old_tf_xla_flags = os.environ.get("TF_XLA_FLAGS", "")
+    os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
+                                  "--tf_xla_min_cluster_size=2 "
+                                  "--tf_xla_cpu_global_jit " + old_tf_xla_flags)
     tf_op, tf_count = self.simpleTest(arg0, arg1,
                                       config_pb2.OptimizerOptions.OFF)
     self.assertEqual(0, tf_count)
@@ -525,6 +526,7 @@ class ElementWiseFusionTest(test.TestCase):
     self.assertEqual(2, tfef_count)
 
     self.assertAllClose(tf_op, tfef_op, rtol=1e-1)
+    os.environ["TF_XLA_FLAGS"] = old_tf_xla_flags
 
 
 class LazyCompilationTest(test.TestCase):
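The save/restore of TF_XLA_FLAGS above is skipped if an assertion fails mid-test, leaking the modified flags into later tests. A more robust variant would restore in a finally block — a hedged sketch with a hypothetical helper, not part of this commit:

import os

def with_tf_xla_flags(extra_flags, body):
    # Hypothetical helper: prepend flags, run `body`, always restore.
    old = os.environ.get("TF_XLA_FLAGS", "")
    os.environ["TF_XLA_FLAGS"] = extra_flags + " " + old
    try:
        body()
    finally:
        os.environ["TF_XLA_FLAGS"] = old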
@@ -36,7 +36,7 @@ class RandomOpsTest(xla_test.XLATestCase):
 
   def _random_types(self):
     return set(self.numeric_types) - set(
-        self.complex_types) - {np.uint8, np.int8}
+        self.complex_types) - {np.uint64, np.int64, np.uint8, np.int8}
 
   def _testRngIsNotConstant(self, rng, dtype):
     # Tests that 'rng' does not always return the same value.
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
@@ -1073,4 +1074,6 @@ class TensorArrayTest(xla_test.XLATestCase):
     self.assertEqual(size1_v, 4)
 
 if __name__ == "__main__":
+  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
+                                os.environ.get("TF_XLA_FLAGS", ""))
   test.main()
@@ -18,6 +18,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import os
 import numpy as np
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import constant_op
@@ -211,4 +212,6 @@ class ListOpsTest(xla_test.XLATestCase):
     self.assertAllEqual(t, [0., 0., 0.])
 
 if __name__ == "__main__":
+  os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' +
+                                os.environ.get('TF_XLA_FLAGS', ''))
   test.main()
@@ -246,4 +246,6 @@ def is_compile_on_demand():
 
 
 if __name__ == "__main__":
+  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
+                                os.environ.get("TF_XLA_FLAGS", ""))
   test.main()
@@ -341,8 +341,8 @@ class LSTMV2Test(keras_parameterized.TestCase):
     cudnn_model.fit(x_train, y_train)
     y_4 = cudnn_model.predict(x_train)
 
-    self.assertAllClose(y_1, y_3, rtol=1e-5, atol=1e-5)
-    self.assertAllClose(y_2, y_4, rtol=1e-5, atol=1e-5)
+    self.assertAllClose(y_1, y_3, rtol=1e-5, atol=2e-5)
+    self.assertAllClose(y_2, y_4, rtol=1e-5, atol=2e-5)
 
   @parameterized.named_parameters(('v0', 0), ('v1', 1), ('v2', 2))
   def DISABLED_test_implementation_mode_LSTM(self, implementation_mode):