diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc index 6df4aa2380e..e58c20dd70a 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc @@ -55,20 +55,19 @@ StatusOr CloneConstantsForBetterClusteringPass::CloneNode( } namespace { -// We only clone host constants for now since we want to avoid increasing memory -// pressure on GPUs. -StatusOr IsSmallHostConstant(Node* n) { - if (!n->IsConstant()) { - return false; +StatusOr IsConstantOnHost(Node* n) { + if (n->output_type(0) == DT_INT32) { + // TensorFlow always puts int32 tensors on the host. + return true; } DeviceNameUtils::ParsedName parsed; TF_RET_CHECK( DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed)); - if (parsed.type != DEVICE_CPU) { - return false; - } + return parsed.type == DEVICE_CPU; +} +StatusOr IsConstantSmall(Node* n) { const TensorProto* proto = nullptr; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto)); @@ -86,6 +85,21 @@ StatusOr IsSmallHostConstant(Node* n) { return total_elements < kSmallTensorThreshold; } +// We only clone host constants for now since we want to avoid increasing memory +// pressure on GPUs. +StatusOr IsSmallHostConstant(Node* n) { + if (!n->IsConstant()) { + return false; + } + + TF_ASSIGN_OR_RETURN(bool is_constant_on_host, IsConstantOnHost(n)); + if (!is_constant_on_host) { + return false; + } + + return IsConstantSmall(n); +} + bool IsInPlaceOp(absl::string_view op_name) { return op_name == "InplaceUpdate" || op_name == "InplaceAdd" || op_name == "InplaceSub"; diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc index 31543d1c3f8..26344721b3e 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -60,7 +61,7 @@ Status CloneConstantsForBetterClustering(const Scope& s, const char* kCPU = "/job:localhost/replica:0/task:0/device:CPU:0"; const char* kGPU = "/job:localhost/replica:0/task:0/device:GPU:0"; -TEST(CloneConstantsForBetterClusteringTest, Basic) { +TEST(CloneConstantsForBetterClusteringTest, HostConstantPlacedOnCpu) { Scope root = Scope::NewRootScope().ExitOnError(); Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU); @@ -87,7 +88,7 @@ TEST(CloneConstantsForBetterClusteringTest, Basic) { EXPECT_NE(tr0_perm.node, tr1_perm.node); } -TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) { +TEST(CloneConstantsForBetterClusteringTest, HostConstantPlacedOnGpu) { Scope root = Scope::NewRootScope().ExitOnError(); Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); @@ -110,6 +111,38 @@ TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) { OutputTensor tr1_perm; TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm)); + EXPECT_NE(tr0_perm.node, tr1_perm.node); +} + +TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + + Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT); + Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT); + + Output perm_f32 = ops::Const(on_gpu.WithOpName("perm"), {3.0, 1.0, 2.0, 0.0}); + Output perm_int0 = + ops::Cast(on_gpu.WithOpName("perm_cast_0"), perm_f32, DT_INT32); + Output perm_int1 = + ops::Cast(on_gpu.WithOpName("perm_cast_1"), perm_f32, DT_INT32); + + { + Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm_int0); + Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm_int1); + } + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor tr0_perm; + TF_ASSERT_OK( + FindNodeByName(result.get(), "perm_cast_0")->input_tensor(0, &tr0_perm)); + + OutputTensor tr1_perm; + TF_ASSERT_OK( + FindNodeByName(result.get(), "perm_cast_1")->input_tensor(0, &tr1_perm)); + EXPECT_EQ(tr0_perm.node, tr1_perm.node); } diff --git a/tensorflow/compiler/jit/tests/keras_imagenet_main.golden_summary b/tensorflow/compiler/jit/tests/keras_imagenet_main.golden_summary index 21b64e483ae..429628e5286 100644 --- a/tensorflow/compiler/jit/tests/keras_imagenet_main.golden_summary +++ b/tensorflow/compiler/jit/tests/keras_imagenet_main.golden_summary @@ -1,10 +1,10 @@ Clustered nodes: 2236 -Unclustered nodes: 618 +Unclustered nodes: 671 Number of clusters: 2 -unclustered size 618 +unclustered size 671 AssignAddVariableOp 1 - Const 120 + Const 173 DivNoNan 1 Identity 2 Merge 53 diff --git a/tensorflow/compiler/jit/tests/keras_imagenet_main_graph_mode.golden_summary b/tensorflow/compiler/jit/tests/keras_imagenet_main_graph_mode.golden_summary index de64044b7c1..4d43d17ee4a 100644 --- a/tensorflow/compiler/jit/tests/keras_imagenet_main_graph_mode.golden_summary +++ b/tensorflow/compiler/jit/tests/keras_imagenet_main_graph_mode.golden_summary @@ -1,10 +1,10 @@ Clustered nodes: 1904 -Unclustered nodes: 455 +Unclustered nodes: 509 Number of clusters: 1 -unclustered size 455 +unclustered size 509 AssignAddVariableOp 2 - Const 11 + Const 65 DivNoNan 1 Identity 1 NoOp 1 diff --git a/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary b/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary index aa2754cf4d1..0c514270adf 100644 --- a/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary +++ b/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary @@ -1,10 +1,11 @@ -Clustered nodes: 1962 -Unclustered nodes: 3974 -Number of clusters: 29 +Clustered nodes: 1970 +Unclustered nodes: 4376 +Number of clusters: 31 -unclustered size 3974 +unclustered size 4376 Add 17 AddN 1 + All 1 ApplyAdam 38 Assert 7 Assign 47 @@ -14,25 +15,23 @@ unclustered size 3974 Cast 8 ConcatOffset 10 ConcatV2 2 - Const 708 + Const 1111 ControlTrigger 5 - DynamicStitch 1 Enter 874 Equal 4 Exit 69 ExpandDims 8 - FloorDiv 1 FloorMod 1 GreaterEqual 7 Identity 113 IsVariableInitialized 1 IteratorGetNext 1 IteratorV2 1 - Less 8 + Less 9 LogicalAnd 3 LoopCond 8 Max 4 - Maximum 7 + Maximum 6 Merge 145 Minimum 5 Mul 8 @@ -63,29 +62,31 @@ unclustered size 3974 Unique 2 VariableV2 164 _Retval 5 -cluster 0 size 37 - All 3 - Cast 2 - ConcatV2 2 - Const 4 - Equal 3 - ExpandDims 4 - Fill 2 - GatherV2 2 - Identity 1 - LessEqual 1 +cluster 0 size 344 + Abs 40 + AddN 1 + Any 41 + Cast 40 + Const 3 + IsInf 1 + IsNan 40 + L2Loss 40 LogicalOr 1 - Max 1 + Max 41 + Minimum 1 + Mul 82 + Pack 3 + Reciprocal 2 + Reshape 2 ReverseSequence 1 - Shape 3 - StridedSlice 3 + Sqrt 1 + Sum 1 Transpose 3 - ZerosLike 1 -cluster 1 size 56 +cluster 1 size 57 BroadcastGradientArgs 1 Cast 5 ConcatV2 1 - Const 1 + Const 2 ExpandDims 3 Less 1 Mean 2 @@ -107,18 +108,7 @@ cluster 1 size 56 Tile 1 Transpose 2 Unpack 1 -cluster 2 size 14 - All 1 - ConcatV2 2 - Const 1 - Equal 1 - ExpandDims 2 - Fill 1 - ReverseSequence 1 - Shape 1 - StridedSlice 1 - Transpose 3 -cluster 3 size 29 +cluster 2 size 29 Cast 2 ConcatV2 2 Const 2 @@ -133,34 +123,38 @@ cluster 3 size 29 Range 1 Reshape 3 Shape 8 -cluster 4 size 344 - Abs 40 +cluster 3 size 9 AddN 1 - Any 41 - Cast 40 - Const 3 - IsInf 1 - IsNan 40 - L2Loss 40 - LogicalOr 1 - Max 41 - Minimum 1 - Mul 82 - Pack 3 - Reciprocal 2 - Reshape 2 - ReverseSequence 1 - Sqrt 1 + MatMul 2 + Mul 1 + Reshape 3 Sum 1 - Transpose 3 -cluster 5 size 3 - Shape 2 Transpose 1 -cluster 6 size 4 - All 1 - Less 1 - LogicalAnd 1 - LogicalNot 1 +cluster 4 size 6 + ReverseSequence 1 + Slice 2 + Transpose 3 +cluster 5 size 27 + All 2 + Cast 1 + ConcatV2 2 + Const 5 + Equal 2 + ExpandDims 4 + Fill 2 + GatherV2 1 + Identity 1 + ReverseSequence 1 + Shape 2 + StridedSlice 2 + Transpose 2 +cluster 6 size 6 + Cast 1 + GatherV2 1 + Shape 1 + StridedSlice 1 + Transpose 1 + ZerosLike 1 cluster 7 size 11 Cast 1 Const 4 @@ -169,14 +163,19 @@ cluster 7 size 11 Mul 2 Pow 1 Sub 1 -cluster 10 size 8 - Add 1 - All 2 - Const 2 - GreaterEqual 1 +cluster 8 size 4 + All 1 Less 1 + LogicalAnd 1 + LogicalNot 1 +cluster 9 size 7 + All 1 + Const 2 + Equal 1 + LessEqual 1 LogicalOr 1 -cluster 11 size 226 + Max 1 +cluster 10 size 226 Add 24 BatchMatMulV2 1 BiasAdd 8 @@ -204,14 +203,23 @@ cluster 11 size 226 StridedSlice 1 Sum 2 Tanh 17 -cluster 12 size 430 +cluster 11 size 5 + Add 1 + All 1 + Const 1 + GreaterEqual 1 + LogicalOr 1 +cluster 14 size 436 Add 22 AddN 41 BatchMatMulV2 2 BiasAddGrad 8 ConcatV2 14 - Const 25 + Const 28 + DynamicStitch 1 + FloorDiv 1 MatMul 20 + Maximum 1 Mul 74 NoOp 13 Reshape 86 @@ -224,35 +232,6 @@ cluster 12 size 430 TanhGrad 17 Tile 2 ZerosLike 1 -cluster 13 size 20 - Add 2 - BiasAdd 1 - ConcatV2 1 - Const 1 - GreaterEqual 1 - MatMul 1 - Mul 3 - Select 3 - Shape 1 - Sigmoid 3 - Split 1 - Tanh 2 -cluster 14 size 52 - Add 2 - AddN 4 - BiasAddGrad 1 - Cast 1 - ConcatV2 1 - Const 3 - MatMul 2 - Mul 6 - NoOp 2 - Reshape 9 - Select 5 - SigmoidGrad 3 - Slice 2 - Sum 9 - TanhGrad 2 cluster 15 size 20 Add 2 BiasAdd 1 @@ -266,11 +245,7 @@ cluster 15 size 20 Sigmoid 3 Split 1 Tanh 2 -cluster 16 size 6 - ReverseSequence 1 - Slice 2 - Transpose 3 -cluster 17 size 52 +cluster 16 size 52 Add 2 AddN 4 BiasAddGrad 1 @@ -286,7 +261,47 @@ cluster 17 size 52 Slice 2 Sum 9 TanhGrad 2 -cluster 19 size 25 +cluster 17 size 15 + All 1 + ConcatV2 2 + Const 2 + Equal 1 + ExpandDims 2 + Fill 1 + ReverseSequence 1 + Shape 1 + StridedSlice 1 + Transpose 3 +cluster 18 size 20 + Add 2 + BiasAdd 1 + ConcatV2 1 + Const 1 + GreaterEqual 1 + MatMul 1 + Mul 3 + Select 3 + Shape 1 + Sigmoid 3 + Split 1 + Tanh 2 +cluster 19 size 52 + Add 2 + AddN 4 + BiasAddGrad 1 + Cast 1 + ConcatV2 1 + Const 3 + MatMul 2 + Mul 6 + NoOp 2 + Reshape 9 + Select 5 + SigmoidGrad 3 + Slice 2 + Sum 9 + TanhGrad 2 +cluster 21 size 25 Add 2 BiasAdd 1 Cast 1 @@ -300,7 +315,20 @@ cluster 19 size 25 Snapshot 1 Split 1 Tanh 2 -cluster 20 size 363 +cluster 22 size 23 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 3 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 +cluster 23 size 363 Add 12 AddN 28 BiasAddGrad 6 @@ -316,97 +344,80 @@ cluster 20 size 363 Slice 12 Sum 76 TanhGrad 12 -cluster 21 size 22 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - MatMul 1 - Mul 5 - Select 2 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 22 size 22 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - MatMul 1 - Mul 5 - Select 2 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 23 size 22 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - MatMul 1 - Mul 5 - Select 2 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 24 size 22 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - MatMul 1 - Mul 5 - Select 2 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 25 size 23 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - MatMul 1 - Mul 5 - Select 3 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 26 size 9 - AddN 1 - MatMul 2 - Mul 1 - Reshape 3 - Sum 1 +cluster 24 size 3 + Shape 2 Transpose 1 -cluster 27 size 9 +cluster 25 size 22 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 2 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 +cluster 26 size 22 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 2 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 +cluster 27 size 22 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 2 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 +cluster 28 size 22 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 2 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 +cluster 29 size 9 Add 1 Mul 2 RealDiv 2 Sqrt 2 Sub 2 -cluster 28 size 9 +cluster 30 size 9 Add 1 Mul 2 RealDiv 2 Sqrt 2 Sub 2 -cluster 29 size 4 +cluster 31 size 4 Mul 3 UnsortedSegmentSum 1 -cluster 30 size 4 +cluster 32 size 4 Mul 3 UnsortedSegmentSum 1 -cluster 31 size 116 +cluster 33 size 116 Cast 38 Const 2 Maximum 38