From c94aab2f47f82bcb2622451312b7fe6eef0016a2 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 13 Mar 2019 17:26:33 -0700 Subject: [PATCH] Change heuristics around creating small clusters - Don't count "Const" nodes in cluster size - Bump up the default minimum cluster size to 4 These changes are motivated by some internal auto-clustering benchmarks. Some tests had to be updated to work with these new heuristics. I've tried to ensure that the tests keep testing what they intended to test. However, these changes do demonstrate a weakness in our tests; the tests not specifically testing our clustering heuristics should be more resilient to changes to our clustering heuristics. PiperOrigin-RevId: 238344897 --- tensorflow/compiler/jit/BUILD | 1 - .../jit/compilation_passes_test_main.cc | 7 +- tensorflow/compiler/jit/flags.cc | 2 +- .../compiler/jit/mark_for_compilation_pass.cc | 27 +++++-- .../jit/mark_for_compilation_pass_test.cc | 70 +++---------------- tensorflow/compiler/tests/jit_test.py | 8 ++- tensorflow/compiler/tests/random_ops_test.py | 2 +- .../compiler/tests/tensor_array_ops_test.py | 3 + .../compiler/tests/tensor_list_ops_test.py | 3 + tensorflow/compiler/tests/while_test.py | 2 + .../python/keras/layers/lstm_v2_test.py | 4 +- 11 files changed, 51 insertions(+), 78 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index ba13e782bd0..1de1e2f78df 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -630,7 +630,6 @@ tf_cc_test( srcs = [ "build_xla_ops_pass_test.cc", "clone_constants_for_better_clustering_test.cc", - "compilation_passes_test_main.cc", "encapsulate_subgraphs_pass_test.cc", "encapsulate_xla_computations_pass_test.cc", "extract_outside_compilation_pass_test.cc", diff --git a/tensorflow/compiler/jit/compilation_passes_test_main.cc b/tensorflow/compiler/jit/compilation_passes_test_main.cc index 4b5c26faeaf..c73702fa642 100644 --- a/tensorflow/compiler/jit/compilation_passes_test_main.cc +++ b/tensorflow/compiler/jit/compilation_passes_test_main.cc @@ -38,10 +38,13 @@ GTEST_API_ int main(int real_argc, char** real_argv) { void operator()(char* ptr) { free(ptr); } }; - std::unique_ptr allocated_arg( + std::unique_ptr enable_global_jit_arg( strdup("--tf_xla_cpu_global_jit=true")); + args.push_back(enable_global_jit_arg.get()); - args.push_back(allocated_arg.get()); + std::unique_ptr reduce_min_cluster_size_arg( + strdup("--tf_xla_min_cluster_size=2")); + args.push_back(reduce_min_cluster_size_arg.get()); int argc = args.size(); diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 8838070c3da..7fcf2b42e43 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -73,7 +73,7 @@ void AllocateAndParseFlags() { mark_for_compilation_flags = new MarkForCompilationPassFlags; mark_for_compilation_flags->tf_xla_auto_jit = 0; - mark_for_compilation_flags->tf_xla_min_cluster_size = 2; + mark_for_compilation_flags->tf_xla_min_cluster_size = 4; mark_for_compilation_flags->tf_xla_max_cluster_size = std::numeric_limits::max(); mark_for_compilation_flags->tf_xla_clustering_debug = false; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index a49bf585c6a..80b3b596763 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1263,13 +1263,24 @@ Status MarkForCompilationPass::RunImpl( // Count the number of non-trivial elements in each cluster. std::vector effective_cluster_sizes(graph->num_node_ids()); + + // has_functional_control_flow remembers if a cluster contains a functional + // control flow node. + std::vector has_functional_control_flow(graph->num_node_ids()); + for (const Node* n : compilation_candidates) { int cluster = clusters[n->id()].Get().representative; - // Identity nodes will be removed if the node gets marked for compilation. - // Therefore we don't want to count them towards the effective cluster size. - if (n->def().op() != "Identity") { + // We want clusters to be big enough that the benefit from XLA's + // optimizations offsets XLA related overhead (for instance we add some + // Switch/Merge nodes into the graph to implement lazy compilation). To + // this end, we don't count Identity and Constant nodes because they do not + // enable interesting optimizations by themselves. + if (!n->IsIdentity() && !n->IsConstant()) { effective_cluster_sizes[cluster]++; } + if (n->type_string() == "While" || n->type_string() == "If") { + has_functional_control_flow[cluster] = true; + } } // Names for each cluster. @@ -1312,11 +1323,13 @@ Status MarkForCompilationPass::RunImpl( marked_for_compilation = compile_attr; } - // Compile if this is a cluster of >= min_cluster_size compilable operators. - // Also, always compile if it contains at least one op that is marked for - // compilation that is not an Identity op. + // We assume that functional If and While nodes have at least + // min_cluster_size non-trivial nodes in them. It would be more principled + // to (recursively) verify this fact, but that's probably not worth the + // trouble. + if (effective_cluster_sizes[cluster_repr] >= min_cluster_size || - (effective_cluster_sizes[cluster_repr] > 0 && marked_for_compilation)) { + has_functional_control_flow[cluster_repr] || marked_for_compilation) { string& name = cluster_names[cluster_repr]; if (name.empty()) { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index f91ce59ad2b..da0fbf35de5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -195,35 +195,6 @@ TEST(XlaCompilationTest, HalfSupported) { EXPECT_FALSE(clusters.empty()); } -TEST(XlaCompilationTest, ConcatWithConstArg) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - GraphDef graphdef; - { - Tensor t(DT_INT32, TensorShape()); - t.scalar()() = 0; - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* dim = ops::SourceOp("Const", builder.opts() - .WithName("Dim") - .WithAttr("dtype", DT_INT32) - .WithAttr("value", t)); - Node* a = ops::SourceOp("Const", builder.opts() - .WithName("A") - .WithAttr("dtype", DT_FLOAT) - .WithAttr("value", t)); - - NodeBuilder concat_builder("Concat", "Concat", - builder.opts().op_registry()); - concat_builder.Input(dim).Input({a, a}).Attr("N", 2); - builder.opts().FinalizeBuilder(&concat_builder); - - TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); - } - - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - auto clusters = GetClusters(*graph); - EXPECT_EQ(3, clusters.size()); // Everything should be compiled. -} - TEST(XlaCompilationTest, FunctionCalls) { FunctionDef compilable = FunctionDefHelper::Define( "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {}, @@ -606,10 +577,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); - ASSERT_EQ(cluster_sets.size(), 1); - std::vector expected_clustered_nodes = {"AssignmentW", - "ValueToAssignW"}; - ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); + ASSERT_EQ(cluster_sets.size(), 0); } TEST(XlaCompilationTest, ChainOfOps) { @@ -637,15 +605,11 @@ TEST(XlaCompilationTest, ChainOfOps) { absl::flat_hash_map> cluster_sets = GetClusterSets(*graph, &cluster_names); - ASSERT_EQ(cluster_sets.size(), 2); + ASSERT_EQ(cluster_sets.size(), 1); - std::vector expected_clustered_nodes_a = {"AssignmentW0", "ConstN0", - "ValueToAssignW0"}; - ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); - - std::vector expected_clustered_nodes_b = { + std::vector expected_clustered_nodes_a = { "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"}; - ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b); + ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); } TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { @@ -704,9 +668,7 @@ TEST(XlaCompilationTest, Retval) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); - EXPECT_EQ(2, clusters.size()); - EXPECT_TRUE(clusters.find("R") == clusters.cend()); - EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_TRUE(clusters.empty()); } TEST(XlaCompilationTest, DontCountIdentityOps) { @@ -725,22 +687,6 @@ TEST(XlaCompilationTest, DontCountIdentityOps) { EXPECT_TRUE(clusters.empty()); } -TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - Scope root = Scope::NewRootScope().ExitOnError(); - { - auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0); - auto b = ops::Identity(root.WithOpName("B"), a); - b.node()->AddAttr(kXlaCompileAttr, true); - auto r = ops::_Retval(root.WithOpName("R"), b, 0); - } - TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - auto clusters = GetClusters(*graph); - - EXPECT_TRUE(clusters.empty()); -} - TEST(XlaCompilationTest, ConstOp) { // valid data type { @@ -996,8 +942,10 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) { absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:XLA_GPU:1"; std::unique_ptr graph(new Graph(OpRegistry::Global())); - Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}); - Output b = ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}); + Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"), + ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2})); + Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"), + ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2})); Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a); Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b); diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index dbea9849e21..777a1562980 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -513,9 +513,10 @@ class ElementWiseFusionTest(test.TestCase): def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = ( - "--tf_xla_fusion_only=true " - "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) + old_tf_xla_flags = os.environ.get("TF_XLA_FLAGS", "") + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " + "--tf_xla_min_cluster_size=2 " + "--tf_xla_cpu_global_jit " + old_tf_xla_flags) tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) @@ -525,6 +526,7 @@ class ElementWiseFusionTest(test.TestCase): self.assertEqual(2, tfef_count) self.assertAllClose(tf_op, tfef_op, rtol=1e-1) + os.environ["TF_XLA_FLAGS"] = old_tf_xla_flags class LazyCompilationTest(test.TestCase): diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 34f2465ba63..0611d6749fa 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -36,7 +36,7 @@ class RandomOpsTest(xla_test.XLATestCase): def _random_types(self): return set(self.numeric_types) - set( - self.complex_types) - {np.uint8, np.int8} + self.complex_types) - {np.uint64, np.int64, np.uint8, np.int8} def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index d2e715025cf..e64aa26cd4b 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.compiler.tests import xla_test @@ -1073,4 +1074,6 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertEqual(size1_v, 4) if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 3c0c36d0c4d..e07b150d601 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op @@ -211,4 +212,6 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(t, [0., 0., 0.]) if __name__ == "__main__": + os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' + + os.environ.get('TF_XLA_FLAGS', '')) test.main() diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index 84237ed2186..15a31111cb6 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -246,4 +246,6 @@ def is_compile_on_demand(): if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/python/keras/layers/lstm_v2_test.py b/tensorflow/python/keras/layers/lstm_v2_test.py index 5bafb56ba2d..26f588e3d2e 100644 --- a/tensorflow/python/keras/layers/lstm_v2_test.py +++ b/tensorflow/python/keras/layers/lstm_v2_test.py @@ -341,8 +341,8 @@ class LSTMV2Test(keras_parameterized.TestCase): cudnn_model.fit(x_train, y_train) y_4 = cudnn_model.predict(x_train) - self.assertAllClose(y_1, y_3, rtol=1e-5, atol=1e-5) - self.assertAllClose(y_2, y_4, rtol=1e-5, atol=1e-5) + self.assertAllClose(y_1, y_3, rtol=1e-5, atol=2e-5) + self.assertAllClose(y_2, y_4, rtol=1e-5, atol=2e-5) @parameterized.named_parameters(('v0', 0), ('v1', 1), ('v2', 2)) def DISABLED_test_implementation_mode_LSTM(self, implementation_mode):