diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index ba13e782bd0..1de1e2f78df 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -630,7 +630,6 @@ tf_cc_test(
     srcs = [
         "build_xla_ops_pass_test.cc",
         "clone_constants_for_better_clustering_test.cc",
-        "compilation_passes_test_main.cc",
         "encapsulate_subgraphs_pass_test.cc",
         "encapsulate_xla_computations_pass_test.cc",
         "extract_outside_compilation_pass_test.cc",
diff --git a/tensorflow/compiler/jit/compilation_passes_test_main.cc b/tensorflow/compiler/jit/compilation_passes_test_main.cc
index 4b5c26faeaf..c73702fa642 100644
--- a/tensorflow/compiler/jit/compilation_passes_test_main.cc
+++ b/tensorflow/compiler/jit/compilation_passes_test_main.cc
@@ -38,10 +38,13 @@ GTEST_API_ int main(int real_argc, char** real_argv) {
     void operator()(char* ptr) { free(ptr); }
   };

-  std::unique_ptr<char, FreeDeleter> allocated_arg(
+  std::unique_ptr<char, FreeDeleter> enable_global_jit_arg(
       strdup("--tf_xla_cpu_global_jit=true"));
+  args.push_back(enable_global_jit_arg.get());

-  args.push_back(allocated_arg.get());
+  std::unique_ptr<char, FreeDeleter> reduce_min_cluster_size_arg(
+      strdup("--tf_xla_min_cluster_size=2"));
+  args.push_back(reduce_min_cluster_size_arg.get());

   int argc = args.size();
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index 8838070c3da..7fcf2b42e43 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -73,7 +73,7 @@ void AllocateAndParseFlags() {
   mark_for_compilation_flags = new MarkForCompilationPassFlags;
   mark_for_compilation_flags->tf_xla_auto_jit = 0;
-  mark_for_compilation_flags->tf_xla_min_cluster_size = 2;
+  mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
   mark_for_compilation_flags->tf_xla_max_cluster_size =
       std::numeric_limits<int32>::max();
   mark_for_compilation_flags->tf_xla_clustering_debug = false;
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index a49bf585c6a..80b3b596763 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -1263,13 +1263,24 @@ Status MarkForCompilationPass::RunImpl(

   // Count the number of non-trivial elements in each cluster.
   std::vector<int> effective_cluster_sizes(graph->num_node_ids());
+
+  // has_functional_control_flow remembers if a cluster contains a functional
+  // control flow node.
+  std::vector<bool> has_functional_control_flow(graph->num_node_ids());
+
   for (const Node* n : compilation_candidates) {
     int cluster = clusters[n->id()].Get().representative;
-    // Identity nodes will be removed if the node gets marked for compilation.
-    // Therefore we don't want to count them towards the effective cluster size.
-    if (n->def().op() != "Identity") {
+    // We want clusters to be big enough that the benefit from XLA's
+    // optimizations offsets XLA related overhead (for instance we add some
+    // Switch/Merge nodes into the graph to implement lazy compilation). To
+    // this end, we don't count Identity and Constant nodes because they do not
+    // enable interesting optimizations by themselves.
+    if (!n->IsIdentity() && !n->IsConstant()) {
       effective_cluster_sizes[cluster]++;
     }
+    if (n->type_string() == "While" || n->type_string() == "If") {
+      has_functional_control_flow[cluster] = true;
+    }
   }

   // Names for each cluster.
@@ -1312,11 +1323,13 @@ Status MarkForCompilationPass::RunImpl(
       marked_for_compilation = compile_attr;
     }

-    // Compile if this is a cluster of >= min_cluster_size compilable operators.
-    // Also, always compile if it contains at least one op that is marked for
-    // compilation that is not an Identity op.
+    // We assume that functional If and While nodes have at least
+    // min_cluster_size non-trivial nodes in them. It would be more principled
+    // to (recursively) verify this fact, but that's probably not worth the
+    // trouble.
+
     if (effective_cluster_sizes[cluster_repr] >= min_cluster_size ||
-        (effective_cluster_sizes[cluster_repr] > 0 && marked_for_compilation)) {
+        has_functional_control_flow[cluster_repr] || marked_for_compilation) {
       string& name = cluster_names[cluster_repr];

       if (name.empty()) {
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index f91ce59ad2b..da0fbf35de5 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -195,35 +195,6 @@ TEST(XlaCompilationTest, HalfSupported) {
   EXPECT_FALSE(clusters.empty());
 }

-TEST(XlaCompilationTest, ConcatWithConstArg) {
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  GraphDef graphdef;
-  {
-    Tensor t(DT_INT32, TensorShape());
-    t.scalar<int32>()() = 0;
-    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
-    Node* dim = ops::SourceOp("Const", builder.opts()
-                                           .WithName("Dim")
-                                           .WithAttr("dtype", DT_INT32)
-                                           .WithAttr("value", t));
-    Node* a = ops::SourceOp("Const", builder.opts()
-                                         .WithName("A")
-                                         .WithAttr("dtype", DT_FLOAT)
-                                         .WithAttr("value", t));
-
-    NodeBuilder concat_builder("Concat", "Concat",
-                               builder.opts().op_registry());
-    concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
-    builder.opts().FinalizeBuilder(&concat_builder);
-
-    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
-  }
-
-  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
-  auto clusters = GetClusters(*graph);
-  EXPECT_EQ(3, clusters.size());  // Everything should be compiled.
-}
-
 TEST(XlaCompilationTest, FunctionCalls) {
   FunctionDef compilable = FunctionDefHelper::Define(
       "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
@@ -606,10 +577,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
   absl::flat_hash_map<string, std::vector<string>> cluster_sets =
       GetClusterSets(*graph);
-  ASSERT_EQ(cluster_sets.size(), 1);
-  std::vector<string> expected_clustered_nodes = {"AssignmentW",
-                                                  "ValueToAssignW"};
-  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
+  ASSERT_EQ(cluster_sets.size(), 0);
 }

 TEST(XlaCompilationTest, ChainOfOps) {
@@ -637,15 +605,11 @@ TEST(XlaCompilationTest, ChainOfOps) {
   absl::flat_hash_map<string, std::vector<string>> cluster_sets =
       GetClusterSets(*graph, &cluster_names);

-  ASSERT_EQ(cluster_sets.size(), 2);
+  ASSERT_EQ(cluster_sets.size(), 1);

-  std::vector<string> expected_clustered_nodes_a = {"AssignmentW0", "ConstN0",
-                                                    "ValueToAssignW0"};
-  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
-
-  std::vector<string> expected_clustered_nodes_b = {
+  std::vector<string> expected_clustered_nodes_a = {
       "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
-  ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b);
+  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
 }

 TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
@@ -704,9 +668,7 @@ TEST(XlaCompilationTest, Retval) {
   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
   auto clusters = GetClusters(*graph);

-  EXPECT_EQ(2, clusters.size());
-  EXPECT_TRUE(clusters.find("R") == clusters.cend());
-  EXPECT_EQ(clusters["A"], clusters["B"]);
+  EXPECT_TRUE(clusters.empty());
 }

 TEST(XlaCompilationTest, DontCountIdentityOps) {
@@ -725,22 +687,6 @@ TEST(XlaCompilationTest, DontCountIdentityOps) {
   EXPECT_TRUE(clusters.empty());
 }

-TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  Scope root = Scope::NewRootScope().ExitOnError();
-  {
-    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
-    auto b = ops::Identity(root.WithOpName("B"), a);
-    b.node()->AddAttr(kXlaCompileAttr, true);
-    auto r = ops::_Retval(root.WithOpName("R"), b, 0);
-  }
-  TF_ASSERT_OK(root.ToGraph(graph.get()));
-  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
-  auto clusters = GetClusters(*graph);
-
-  EXPECT_TRUE(clusters.empty());
-}
-
 TEST(XlaCompilationTest, ConstOp) {
   // valid data type
   {
@@ -996,8 +942,10 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) {
   absl::string_view xla_gpu_dev1 =
       "/job:worker/replica:0/task:0/device:XLA_GPU:1";
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
-  Output b = ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2});
+  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
+                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
+  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
+                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
   Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
   Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index dbea9849e21..777a1562980 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -513,9 +513,10 @@ class ElementWiseFusionTest(test.TestCase):
   def testElementWiseClustering(self):
     arg0 = np.random.rand(2, 2).astype(np.float32)
     arg1 = np.random.rand(2, 2).astype(np.float32)
-    os.environ["TF_XLA_FLAGS"] = (
-        "--tf_xla_fusion_only=true "
-        "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
+    old_tf_xla_flags = os.environ.get("TF_XLA_FLAGS", "")
+    os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
+                                  "--tf_xla_min_cluster_size=2 "
+                                  "--tf_xla_cpu_global_jit " + old_tf_xla_flags)
     tf_op, tf_count = self.simpleTest(arg0, arg1,
                                       config_pb2.OptimizerOptions.OFF)
     self.assertEqual(0, tf_count)
@@ -525,6 +526,7 @@ class ElementWiseFusionTest(test.TestCase):
     self.assertEqual(2, tfef_count)

     self.assertAllClose(tf_op, tfef_op, rtol=1e-1)
+    os.environ["TF_XLA_FLAGS"] = old_tf_xla_flags


 class LazyCompilationTest(test.TestCase):
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 34f2465ba63..0611d6749fa 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -36,7 +36,7 @@ class RandomOpsTest(xla_test.XLATestCase):

   def _random_types(self):
     return set(self.numeric_types) - set(
-        self.complex_types) - {np.uint8, np.int8}
+        self.complex_types) - {np.uint64, np.int64, np.uint8, np.int8}

   def _testRngIsNotConstant(self, rng, dtype):
     # Tests that 'rng' does not always return the same value.
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index d2e715025cf..e64aa26cd4b 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import numpy as np

 from tensorflow.compiler.tests import xla_test
@@ -1073,4 +1074,6 @@ class TensorArrayTest(xla_test.XLATestCase):
     self.assertEqual(size1_v, 4)

 if __name__ == "__main__":
+  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
+                                os.environ.get("TF_XLA_FLAGS", ""))
   test.main()
diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py
index 3c0c36d0c4d..e07b150d601 100644
--- a/tensorflow/compiler/tests/tensor_list_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_list_ops_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import numpy as np
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import constant_op
@@ -211,4 +212,6 @@ class ListOpsTest(xla_test.XLATestCase):
     self.assertAllEqual(t, [0., 0., 0.])

 if __name__ == "__main__":
+  os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' +
+                                os.environ.get('TF_XLA_FLAGS', ''))
   test.main()
diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py
index 84237ed2186..15a31111cb6 100644
--- a/tensorflow/compiler/tests/while_test.py
+++ b/tensorflow/compiler/tests/while_test.py
@@ -246,4 +246,6 @@ def is_compile_on_demand():


 if __name__ == "__main__":
+  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
+                                os.environ.get("TF_XLA_FLAGS", ""))
   test.main()
diff --git a/tensorflow/python/keras/layers/lstm_v2_test.py b/tensorflow/python/keras/layers/lstm_v2_test.py
index 5bafb56ba2d..26f588e3d2e 100644
--- a/tensorflow/python/keras/layers/lstm_v2_test.py
+++ b/tensorflow/python/keras/layers/lstm_v2_test.py
@@ -341,8 +341,8 @@ class LSTMV2Test(keras_parameterized.TestCase):
     cudnn_model.fit(x_train, y_train)
     y_4 = cudnn_model.predict(x_train)

-    self.assertAllClose(y_1, y_3, rtol=1e-5, atol=1e-5)
-    self.assertAllClose(y_2, y_4, rtol=1e-5, atol=1e-5)
+    self.assertAllClose(y_1, y_3, rtol=1e-5, atol=2e-5)
+    self.assertAllClose(y_2, y_4, rtol=1e-5, atol=2e-5)

   @parameterized.named_parameters(('v0', 0), ('v1', 1), ('v2', 2))
   def DISABLED_test_implementation_mode_LSTM(self, implementation_mode):
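
Note on the behavior change above: the default for tf_xla_min_cluster_size moves from 2 to 4, and the updated Python tests compensate by prepending --tf_xla_min_cluster_size=2 to the TF_XLA_FLAGS environment variable before calling test.main(). A minimal, self-contained sketch of that pattern for any other test that relies on small clusters being compiled (the test case name and body below are hypothetical placeholders, not part of this change):

    import os

    from tensorflow.python.platform import test


    class SmallClusterTest(test.TestCase):  # hypothetical example test case

      def testTrivial(self):
        self.assertEqual(2, 1 + 1)


    if __name__ == "__main__":
      # Restore the pre-change clustering threshold for this test binary only,
      # preserving any flags the caller already placed in TF_XLA_FLAGS.
      os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
                                    os.environ.get("TF_XLA_FLAGS", ""))
      test.main()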