Change heuristics around creating small clusters
- Don't count "Const" nodes in cluster size - Bump up the default minimum cluster size to 4 These changes are motivated by some internal auto-clustering benchmarks. Some tests had to be updated to work with these new heuristics. I've tried to ensure that the tests keep testing what they intended to test. However, these changes do demonstrate a weakness in our tests; the tests not specifically testing our clustering heuristics should be more resilient to changes to our clustering heuristics. PiperOrigin-RevId: 238344897
This commit is contained in:
parent
825f04f9f2
commit
c94aab2f47
@ -630,7 +630,6 @@ tf_cc_test(
|
||||
srcs = [
|
||||
"build_xla_ops_pass_test.cc",
|
||||
"clone_constants_for_better_clustering_test.cc",
|
||||
"compilation_passes_test_main.cc",
|
||||
"encapsulate_subgraphs_pass_test.cc",
|
||||
"encapsulate_xla_computations_pass_test.cc",
|
||||
"extract_outside_compilation_pass_test.cc",
|
||||
|
@ -38,10 +38,13 @@ GTEST_API_ int main(int real_argc, char** real_argv) {
|
||||
void operator()(char* ptr) { free(ptr); }
|
||||
};
|
||||
|
||||
std::unique_ptr<char, FreeDeleter> allocated_arg(
|
||||
std::unique_ptr<char, FreeDeleter> enable_global_jit_arg(
|
||||
strdup("--tf_xla_cpu_global_jit=true"));
|
||||
args.push_back(enable_global_jit_arg.get());
|
||||
|
||||
args.push_back(allocated_arg.get());
|
||||
std::unique_ptr<char, FreeDeleter> reduce_min_cluster_size_arg(
|
||||
strdup("--tf_xla_min_cluster_size=2"));
|
||||
args.push_back(reduce_min_cluster_size_arg.get());
|
||||
|
||||
int argc = args.size();
|
||||
|
||||
|
@ -73,7 +73,7 @@ void AllocateAndParseFlags() {
|
||||
|
||||
mark_for_compilation_flags = new MarkForCompilationPassFlags;
|
||||
mark_for_compilation_flags->tf_xla_auto_jit = 0;
|
||||
mark_for_compilation_flags->tf_xla_min_cluster_size = 2;
|
||||
mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
|
||||
mark_for_compilation_flags->tf_xla_max_cluster_size =
|
||||
std::numeric_limits<int32>::max();
|
||||
mark_for_compilation_flags->tf_xla_clustering_debug = false;
|
||||
|
@ -1263,13 +1263,24 @@ Status MarkForCompilationPass::RunImpl(
|
||||
|
||||
// Count the number of non-trivial elements in each cluster.
|
||||
std::vector<int> effective_cluster_sizes(graph->num_node_ids());
|
||||
|
||||
// has_functional_control_flow remembers if a cluster contains a functional
|
||||
// control flow node.
|
||||
std::vector<bool> has_functional_control_flow(graph->num_node_ids());
|
||||
|
||||
for (const Node* n : compilation_candidates) {
|
||||
int cluster = clusters[n->id()].Get().representative;
|
||||
// Identity nodes will be removed if the node gets marked for compilation.
|
||||
// Therefore we don't want to count them towards the effective cluster size.
|
||||
if (n->def().op() != "Identity") {
|
||||
// We want clusters to be big enough that the benefit from XLA's
|
||||
// optimizations offsets XLA related overhead (for instance we add some
|
||||
// Switch/Merge nodes into the graph to implement lazy compilation). To
|
||||
// this end, we don't count Identity and Constant nodes because they do not
|
||||
// enable interesting optimizations by themselves.
|
||||
if (!n->IsIdentity() && !n->IsConstant()) {
|
||||
effective_cluster_sizes[cluster]++;
|
||||
}
|
||||
if (n->type_string() == "While" || n->type_string() == "If") {
|
||||
has_functional_control_flow[cluster] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Names for each cluster.
|
||||
@ -1312,11 +1323,13 @@ Status MarkForCompilationPass::RunImpl(
|
||||
marked_for_compilation = compile_attr;
|
||||
}
|
||||
|
||||
// Compile if this is a cluster of >= min_cluster_size compilable operators.
|
||||
// Also, always compile if it contains at least one op that is marked for
|
||||
// compilation that is not an Identity op.
|
||||
// We assume that functional If and While nodes have at least
|
||||
// min_cluster_size non-trivial nodes in them. It would be more principled
|
||||
// to (recursively) verify this fact, but that's probably not worth the
|
||||
// trouble.
|
||||
|
||||
if (effective_cluster_sizes[cluster_repr] >= min_cluster_size ||
|
||||
(effective_cluster_sizes[cluster_repr] > 0 && marked_for_compilation)) {
|
||||
has_functional_control_flow[cluster_repr] || marked_for_compilation) {
|
||||
string& name = cluster_names[cluster_repr];
|
||||
|
||||
if (name.empty()) {
|
||||
|
@ -195,35 +195,6 @@ TEST(XlaCompilationTest, HalfSupported) {
|
||||
EXPECT_FALSE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, ConcatWithConstArg) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
Tensor t(DT_INT32, TensorShape());
|
||||
t.scalar<int32>()() = 0;
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* dim = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("Dim")
|
||||
.WithAttr("dtype", DT_INT32)
|
||||
.WithAttr("value", t));
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", t));
|
||||
|
||||
NodeBuilder concat_builder("Concat", "Concat",
|
||||
builder.opts().op_registry());
|
||||
concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
|
||||
builder.opts().FinalizeBuilder(&concat_builder);
|
||||
|
||||
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
auto clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(3, clusters.size()); // Everything should be compiled.
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, FunctionCalls) {
|
||||
FunctionDef compilable = FunctionDefHelper::Define(
|
||||
"CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
|
||||
@ -606,10 +577,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
absl::flat_hash_map<string, std::vector<string>> cluster_sets =
|
||||
GetClusterSets(*graph);
|
||||
ASSERT_EQ(cluster_sets.size(), 1);
|
||||
std::vector<string> expected_clustered_nodes = {"AssignmentW",
|
||||
"ValueToAssignW"};
|
||||
ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
|
||||
ASSERT_EQ(cluster_sets.size(), 0);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, ChainOfOps) {
|
||||
@ -637,15 +605,11 @@ TEST(XlaCompilationTest, ChainOfOps) {
|
||||
absl::flat_hash_map<string, std::vector<string>> cluster_sets =
|
||||
GetClusterSets(*graph, &cluster_names);
|
||||
|
||||
ASSERT_EQ(cluster_sets.size(), 2);
|
||||
ASSERT_EQ(cluster_sets.size(), 1);
|
||||
|
||||
std::vector<string> expected_clustered_nodes_a = {"AssignmentW0", "ConstN0",
|
||||
"ValueToAssignW0"};
|
||||
ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
|
||||
|
||||
std::vector<string> expected_clustered_nodes_b = {
|
||||
std::vector<string> expected_clustered_nodes_a = {
|
||||
"AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
|
||||
ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b);
|
||||
ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
|
||||
@ -704,9 +668,7 @@ TEST(XlaCompilationTest, Retval) {
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_EQ(2, clusters.size());
|
||||
EXPECT_TRUE(clusters.find("R") == clusters.cend());
|
||||
EXPECT_EQ(clusters["A"], clusters["B"]);
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, DontCountIdentityOps) {
|
||||
@ -725,22 +687,6 @@ TEST(XlaCompilationTest, DontCountIdentityOps) {
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
{
|
||||
auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
|
||||
auto b = ops::Identity(root.WithOpName("B"), a);
|
||||
b.node()->AddAttr(kXlaCompileAttr, true);
|
||||
auto r = ops::_Retval(root.WithOpName("R"), b, 0);
|
||||
}
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, ConstOp) {
|
||||
// valid data type
|
||||
{
|
||||
@ -996,8 +942,10 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) {
|
||||
absl::string_view xla_gpu_dev1 =
|
||||
"/job:worker/replica:0/task:0/device:XLA_GPU:1";
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
|
||||
Output b = ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2});
|
||||
Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
|
||||
ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
|
||||
Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
|
||||
ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
|
||||
Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
|
||||
Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);
|
||||
|
||||
|
@ -513,9 +513,10 @@ class ElementWiseFusionTest(test.TestCase):
|
||||
def testElementWiseClustering(self):
|
||||
arg0 = np.random.rand(2, 2).astype(np.float32)
|
||||
arg1 = np.random.rand(2, 2).astype(np.float32)
|
||||
os.environ["TF_XLA_FLAGS"] = (
|
||||
"--tf_xla_fusion_only=true "
|
||||
"--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
|
||||
old_tf_xla_flags = os.environ.get("TF_XLA_FLAGS", "")
|
||||
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
|
||||
"--tf_xla_min_cluster_size=2 "
|
||||
"--tf_xla_cpu_global_jit " + old_tf_xla_flags)
|
||||
tf_op, tf_count = self.simpleTest(arg0, arg1,
|
||||
config_pb2.OptimizerOptions.OFF)
|
||||
self.assertEqual(0, tf_count)
|
||||
@ -525,6 +526,7 @@ class ElementWiseFusionTest(test.TestCase):
|
||||
self.assertEqual(2, tfef_count)
|
||||
|
||||
self.assertAllClose(tf_op, tfef_op, rtol=1e-1)
|
||||
os.environ["TF_XLA_FLAGS"] = old_tf_xla_flags
|
||||
|
||||
|
||||
class LazyCompilationTest(test.TestCase):
|
||||
|
@ -36,7 +36,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def _random_types(self):
|
||||
return set(self.numeric_types) - set(
|
||||
self.complex_types) - {np.uint8, np.int8}
|
||||
self.complex_types) - {np.uint64, np.int64, np.uint8, np.int8}
|
||||
|
||||
def _testRngIsNotConstant(self, rng, dtype):
|
||||
# Tests that 'rng' does not always return the same value.
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
@ -1073,4 +1074,6 @@ class TensorArrayTest(xla_test.XLATestCase):
|
||||
self.assertEqual(size1_v, 4)
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
|
||||
os.environ.get("TF_XLA_FLAGS", ""))
|
||||
test.main()
|
||||
|
@ -18,6 +18,7 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -211,4 +212,6 @@ class ListOpsTest(xla_test.XLATestCase):
|
||||
self.assertAllEqual(t, [0., 0., 0.])
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' +
|
||||
os.environ.get('TF_XLA_FLAGS', ''))
|
||||
test.main()
|
||||
|
@ -246,4 +246,6 @@ def is_compile_on_demand():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
|
||||
os.environ.get("TF_XLA_FLAGS", ""))
|
||||
test.main()
|
||||
|
@ -341,8 +341,8 @@ class LSTMV2Test(keras_parameterized.TestCase):
|
||||
cudnn_model.fit(x_train, y_train)
|
||||
y_4 = cudnn_model.predict(x_train)
|
||||
|
||||
self.assertAllClose(y_1, y_3, rtol=1e-5, atol=1e-5)
|
||||
self.assertAllClose(y_2, y_4, rtol=1e-5, atol=1e-5)
|
||||
self.assertAllClose(y_1, y_3, rtol=1e-5, atol=2e-5)
|
||||
self.assertAllClose(y_2, y_4, rtol=1e-5, atol=2e-5)
|
||||
|
||||
@parameterized.named_parameters(('v0', 0), ('v1', 1), ('v2', 2))
|
||||
def DISABLED_test_implementation_mode_LSTM(self, implementation_mode):
|
||||
|
Loading…
Reference in New Issue
Block a user