Make constant cloning more aggressive (i.e. make it clone more often)

Constant nodes placed on the GPU can also be "host constants" because int32
constants are produced on the host.

PiperOrigin-RevId: 273345880
This commit is contained in:
Sanjoy Das 2019-10-07 12:07:05 -07:00 committed by TensorFlower Gardener
parent 11cadc64c5
commit 718ecede1c
5 changed files with 257 additions and 199 deletions

View File

@ -55,20 +55,19 @@ StatusOr<Node*> CloneConstantsForBetterClusteringPass::CloneNode(
}
namespace {
// We only clone host constants for now since we want to avoid increasing memory
// pressure on GPUs.
StatusOr<bool> IsSmallHostConstant(Node* n) {
if (!n->IsConstant()) {
return false;
StatusOr<bool> IsConstantOnHost(Node* n) {
if (n->output_type(0) == DT_INT32) {
// TensorFlow always puts int32 tensors on the host.
return true;
}
DeviceNameUtils::ParsedName parsed;
TF_RET_CHECK(
DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed));
if (parsed.type != DEVICE_CPU) {
return false;
}
return parsed.type == DEVICE_CPU;
}
StatusOr<bool> IsConstantSmall(Node* n) {
const TensorProto* proto = nullptr;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
@ -86,6 +85,21 @@ StatusOr<bool> IsSmallHostConstant(Node* n) {
return total_elements < kSmallTensorThreshold;
}
// We only clone host constants for now since we want to avoid increasing memory
// pressure on GPUs.
StatusOr<bool> IsSmallHostConstant(Node* n) {
if (!n->IsConstant()) {
return false;
}
TF_ASSIGN_OR_RETURN(bool is_constant_on_host, IsConstantOnHost(n));
if (!is_constant_on_host) {
return false;
}
return IsConstantSmall(n);
}
bool IsInPlaceOp(absl::string_view op_name) {
return op_name == "InplaceUpdate" || op_name == "InplaceAdd" ||
op_name == "InplaceSub";

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -60,7 +61,7 @@ Status CloneConstantsForBetterClustering(const Scope& s,
const char* kCPU = "/job:localhost/replica:0/task:0/device:CPU:0";
const char* kGPU = "/job:localhost/replica:0/task:0/device:GPU:0";
TEST(CloneConstantsForBetterClusteringTest, Basic) {
TEST(CloneConstantsForBetterClusteringTest, HostConstantPlacedOnCpu) {
Scope root = Scope::NewRootScope().ExitOnError();
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU);
@ -87,7 +88,7 @@ TEST(CloneConstantsForBetterClusteringTest, Basic) {
EXPECT_NE(tr0_perm.node, tr1_perm.node);
}
TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) {
TEST(CloneConstantsForBetterClusteringTest, HostConstantPlacedOnGpu) {
Scope root = Scope::NewRootScope().ExitOnError();
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
@ -110,6 +111,38 @@ TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) {
OutputTensor tr1_perm;
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
EXPECT_NE(tr0_perm.node, tr1_perm.node);
}
TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) {
Scope root = Scope::NewRootScope().ExitOnError();
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
Output perm_f32 = ops::Const(on_gpu.WithOpName("perm"), {3.0, 1.0, 2.0, 0.0});
Output perm_int0 =
ops::Cast(on_gpu.WithOpName("perm_cast_0"), perm_f32, DT_INT32);
Output perm_int1 =
ops::Cast(on_gpu.WithOpName("perm_cast_1"), perm_f32, DT_INT32);
{
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm_int0);
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm_int1);
}
std::unique_ptr<Graph> result;
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
OutputTensor tr0_perm;
TF_ASSERT_OK(
FindNodeByName(result.get(), "perm_cast_0")->input_tensor(0, &tr0_perm));
OutputTensor tr1_perm;
TF_ASSERT_OK(
FindNodeByName(result.get(), "perm_cast_1")->input_tensor(0, &tr1_perm));
EXPECT_EQ(tr0_perm.node, tr1_perm.node);
}

View File

@ -1,10 +1,10 @@
Clustered nodes: 2236
Unclustered nodes: 618
Unclustered nodes: 671
Number of clusters: 2
unclustered size 618
unclustered size 671
AssignAddVariableOp 1
Const 120
Const 173
DivNoNan 1
Identity 2
Merge 53

View File

@ -1,10 +1,10 @@
Clustered nodes: 1904
Unclustered nodes: 455
Unclustered nodes: 509
Number of clusters: 1
unclustered size 455
unclustered size 509
AssignAddVariableOp 2
Const 11
Const 65
DivNoNan 1
Identity 1
NoOp 1

View File

@ -1,10 +1,11 @@
Clustered nodes: 1962
Unclustered nodes: 3974
Number of clusters: 29
Clustered nodes: 1970
Unclustered nodes: 4376
Number of clusters: 31
unclustered size 3974
unclustered size 4376
Add 17
AddN 1
All 1
ApplyAdam 38
Assert 7
Assign 47
@ -14,25 +15,23 @@ unclustered size 3974
Cast 8
ConcatOffset 10
ConcatV2 2
Const 708
Const 1111
ControlTrigger 5
DynamicStitch 1
Enter 874
Equal 4
Exit 69
ExpandDims 8
FloorDiv 1
FloorMod 1
GreaterEqual 7
Identity 113
IsVariableInitialized 1
IteratorGetNext 1
IteratorV2 1
Less 8
Less 9
LogicalAnd 3
LoopCond 8
Max 4
Maximum 7
Maximum 6
Merge 145
Minimum 5
Mul 8
@ -63,29 +62,31 @@ unclustered size 3974
Unique 2
VariableV2 164
_Retval 5
cluster 0 size 37
All 3
Cast 2
ConcatV2 2
Const 4
Equal 3
ExpandDims 4
Fill 2
GatherV2 2
Identity 1
LessEqual 1
cluster 0 size 344
Abs 40
AddN 1
Any 41
Cast 40
Const 3
IsInf 1
IsNan 40
L2Loss 40
LogicalOr 1
Max 1
Max 41
Minimum 1
Mul 82
Pack 3
Reciprocal 2
Reshape 2
ReverseSequence 1
Shape 3
StridedSlice 3
Sqrt 1
Sum 1
Transpose 3
ZerosLike 1
cluster 1 size 56
cluster 1 size 57
BroadcastGradientArgs 1
Cast 5
ConcatV2 1
Const 1
Const 2
ExpandDims 3
Less 1
Mean 2
@ -107,18 +108,7 @@ cluster 1 size 56
Tile 1
Transpose 2
Unpack 1
cluster 2 size 14
All 1
ConcatV2 2
Const 1
Equal 1
ExpandDims 2
Fill 1
ReverseSequence 1
Shape 1
StridedSlice 1
Transpose 3
cluster 3 size 29
cluster 2 size 29
Cast 2
ConcatV2 2
Const 2
@ -133,34 +123,38 @@ cluster 3 size 29
Range 1
Reshape 3
Shape 8
cluster 4 size 344
Abs 40
cluster 3 size 9
AddN 1
Any 41
Cast 40
Const 3
IsInf 1
IsNan 40
L2Loss 40
LogicalOr 1
Max 41
Minimum 1
Mul 82
Pack 3
Reciprocal 2
Reshape 2
ReverseSequence 1
Sqrt 1
MatMul 2
Mul 1
Reshape 3
Sum 1
Transpose 3
cluster 5 size 3
Shape 2
Transpose 1
cluster 6 size 4
All 1
Less 1
LogicalAnd 1
LogicalNot 1
cluster 4 size 6
ReverseSequence 1
Slice 2
Transpose 3
cluster 5 size 27
All 2
Cast 1
ConcatV2 2
Const 5
Equal 2
ExpandDims 4
Fill 2
GatherV2 1
Identity 1
ReverseSequence 1
Shape 2
StridedSlice 2
Transpose 2
cluster 6 size 6
Cast 1
GatherV2 1
Shape 1
StridedSlice 1
Transpose 1
ZerosLike 1
cluster 7 size 11
Cast 1
Const 4
@ -169,14 +163,19 @@ cluster 7 size 11
Mul 2
Pow 1
Sub 1
cluster 10 size 8
Add 1
All 2
Const 2
GreaterEqual 1
cluster 8 size 4
All 1
Less 1
LogicalAnd 1
LogicalNot 1
cluster 9 size 7
All 1
Const 2
Equal 1
LessEqual 1
LogicalOr 1
cluster 11 size 226
Max 1
cluster 10 size 226
Add 24
BatchMatMulV2 1
BiasAdd 8
@ -204,14 +203,23 @@ cluster 11 size 226
StridedSlice 1
Sum 2
Tanh 17
cluster 12 size 430
cluster 11 size 5
Add 1
All 1
Const 1
GreaterEqual 1
LogicalOr 1
cluster 14 size 436
Add 22
AddN 41
BatchMatMulV2 2
BiasAddGrad 8
ConcatV2 14
Const 25
Const 28
DynamicStitch 1
FloorDiv 1
MatMul 20
Maximum 1
Mul 74
NoOp 13
Reshape 86
@ -224,35 +232,6 @@ cluster 12 size 430
TanhGrad 17
Tile 2
ZerosLike 1
cluster 13 size 20
Add 2
BiasAdd 1
ConcatV2 1
Const 1
GreaterEqual 1
MatMul 1
Mul 3
Select 3
Shape 1
Sigmoid 3
Split 1
Tanh 2
cluster 14 size 52
Add 2
AddN 4
BiasAddGrad 1
Cast 1
ConcatV2 1
Const 3
MatMul 2
Mul 6
NoOp 2
Reshape 9
Select 5
SigmoidGrad 3
Slice 2
Sum 9
TanhGrad 2
cluster 15 size 20
Add 2
BiasAdd 1
@ -266,11 +245,7 @@ cluster 15 size 20
Sigmoid 3
Split 1
Tanh 2
cluster 16 size 6
ReverseSequence 1
Slice 2
Transpose 3
cluster 17 size 52
cluster 16 size 52
Add 2
AddN 4
BiasAddGrad 1
@ -286,7 +261,47 @@ cluster 17 size 52
Slice 2
Sum 9
TanhGrad 2
cluster 19 size 25
cluster 17 size 15
All 1
ConcatV2 2
Const 2
Equal 1
ExpandDims 2
Fill 1
ReverseSequence 1
Shape 1
StridedSlice 1
Transpose 3
cluster 18 size 20
Add 2
BiasAdd 1
ConcatV2 1
Const 1
GreaterEqual 1
MatMul 1
Mul 3
Select 3
Shape 1
Sigmoid 3
Split 1
Tanh 2
cluster 19 size 52
Add 2
AddN 4
BiasAddGrad 1
Cast 1
ConcatV2 1
Const 3
MatMul 2
Mul 6
NoOp 2
Reshape 9
Select 5
SigmoidGrad 3
Slice 2
Sum 9
TanhGrad 2
cluster 21 size 25
Add 2
BiasAdd 1
Cast 1
@ -300,7 +315,20 @@ cluster 19 size 25
Snapshot 1
Split 1
Tanh 2
cluster 20 size 363
cluster 22 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 3
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 23 size 363
Add 12
AddN 28
BiasAddGrad 6
@ -316,97 +344,80 @@ cluster 20 size 363
Slice 12
Sum 76
TanhGrad 12
cluster 21 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 22 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 23 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 24 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 25 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 3
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 26 size 9
AddN 1
MatMul 2
Mul 1
Reshape 3
Sum 1
cluster 24 size 3
Shape 2
Transpose 1
cluster 27 size 9
cluster 25 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 26 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 27 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 28 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 29 size 9
Add 1
Mul 2
RealDiv 2
Sqrt 2
Sub 2
cluster 28 size 9
cluster 30 size 9
Add 1
Mul 2
RealDiv 2
Sqrt 2
Sub 2
cluster 29 size 4
cluster 31 size 4
Mul 3
UnsortedSegmentSum 1
cluster 30 size 4
cluster 32 size 4
Mul 3
UnsortedSegmentSum 1
cluster 31 size 116
cluster 33 size 116
Cast 38
Const 2
Maximum 38