Make constant cloning more aggressive (i.e. make it clone more often)
Constant nodes placed on the GPU can also be "host constants" because int32 constants are produced on the host.

PiperOrigin-RevId: 273345880
This commit is contained in:
parent
11cadc64c5
commit
718ecede1c
@ -55,20 +55,19 @@ StatusOr<Node*> CloneConstantsForBetterClusteringPass::CloneNode(
|
||||
}
|
||||
|
||||
namespace {
|
||||
// We only clone host constants for now since we want to avoid increasing memory
|
||||
// pressure on GPUs.
|
||||
StatusOr<bool> IsSmallHostConstant(Node* n) {
|
||||
if (!n->IsConstant()) {
|
||||
return false;
|
||||
StatusOr<bool> IsConstantOnHost(Node* n) {
|
||||
if (n->output_type(0) == DT_INT32) {
|
||||
// TensorFlow always puts int32 tensors on the host.
|
||||
return true;
|
||||
}
|
||||
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
TF_RET_CHECK(
|
||||
DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed));
|
||||
if (parsed.type != DEVICE_CPU) {
|
||||
return false;
|
||||
}
|
||||
return parsed.type == DEVICE_CPU;
|
||||
}
|
||||
|
||||
StatusOr<bool> IsConstantSmall(Node* n) {
|
||||
const TensorProto* proto = nullptr;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
|
||||
|
||||
@ -86,6 +85,21 @@ StatusOr<bool> IsSmallHostConstant(Node* n) {
|
||||
return total_elements < kSmallTensorThreshold;
|
||||
}
|
||||
|
||||
// We only clone host constants for now since we want to avoid increasing memory
|
||||
// pressure on GPUs.
|
||||
// Reports whether `n` passes all three checks performed below: it is a
// constant node, IsConstantOnHost(n) yields true, and IsConstantSmall(n)
// yields true. The first failing check short-circuits with `false`.
// NOTE(review): TF_ASSIGN_OR_RETURN presumably propagates a non-OK result from
// IsConstantOnHost to the caller; a non-OK result from IsConstantSmall is
// returned as-is by the final return statement.
StatusOr<bool> IsSmallHostConstant(Node* n) {
|
||||
// Anything that is not a constant node is rejected up front.
if (!n->IsConstant()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Placement check comes before the size check.
TF_ASSIGN_OR_RETURN(bool is_constant_on_host, IsConstantOnHost(n));
|
||||
if (!is_constant_on_host) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The node is a host constant; the size check decides the final answer.
return IsConstantSmall(n);
|
||||
}
|
||||
|
||||
bool IsInPlaceOp(absl::string_view op_name) {
|
||||
return op_name == "InplaceUpdate" || op_name == "InplaceAdd" ||
|
||||
op_name == "InplaceSub";
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/compiler/jit/node_matchers.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -60,7 +61,7 @@ Status CloneConstantsForBetterClustering(const Scope& s,
|
||||
const char* kCPU = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char* kGPU = "/job:localhost/replica:0/task:0/device:GPU:0";
|
||||
|
||||
TEST(CloneConstantsForBetterClusteringTest, Basic) {
|
||||
TEST(CloneConstantsForBetterClusteringTest, HostConstantPlacedOnCpu) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
|
||||
Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU);
|
||||
@ -87,7 +88,7 @@ TEST(CloneConstantsForBetterClusteringTest, Basic) {
|
||||
EXPECT_NE(tr0_perm.node, tr1_perm.node);
|
||||
}
|
||||
|
||||
TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) {
|
||||
TEST(CloneConstantsForBetterClusteringTest, HostConstantPlacedOnGpu) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
|
||||
|
||||
@ -110,6 +111,38 @@ TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) {
|
||||
OutputTensor tr1_perm;
|
||||
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
|
||||
|
||||
EXPECT_NE(tr0_perm.node, tr1_perm.node);
|
||||
}
|
||||
|
||||
// Builds a graph in which one float constant created in the GPU scope feeds
// two Cast-to-int32 nodes, each of which feeds a Transpose, then runs
// CloneConstantsForBetterClustering and expects both Cast nodes to still read
// from the very same producer node (i.e. the constant was not duplicated).
TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
// Every node in this test is created through the GPU-assigned scope.
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
|
||||
|
||||
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
|
||||
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
|
||||
|
||||
// The permutation is stored as floating-point literals; it only becomes
// int32 after the Cast nodes below, so the Const itself is not an int32 node.
Output perm_f32 = ops::Const(on_gpu.WithOpName("perm"), {3.0, 1.0, 2.0, 0.0});
|
||||
Output perm_int0 =
|
||||
ops::Cast(on_gpu.WithOpName("perm_cast_0"), perm_f32, DT_INT32);
|
||||
Output perm_int1 =
|
||||
ops::Cast(on_gpu.WithOpName("perm_cast_1"), perm_f32, DT_INT32);
|
||||
|
||||
// The Transpose outputs are not needed afterwards; the nodes are looked up by
// name in the resulting graph instead.
{
|
||||
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm_int0);
|
||||
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm_int1);
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> result;
|
||||
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
|
||||
|
||||
// Fetch the producer of input 0 of each Cast node from the resulting graph.
OutputTensor tr0_perm;
|
||||
TF_ASSERT_OK(
|
||||
FindNodeByName(result.get(), "perm_cast_0")->input_tensor(0, &tr0_perm));
|
||||
|
||||
OutputTensor tr1_perm;
|
||||
TF_ASSERT_OK(
|
||||
FindNodeByName(result.get(), "perm_cast_1")->input_tensor(0, &tr1_perm));
|
||||
|
||||
// Same node on both sides means the shared constant was left un-cloned.
EXPECT_EQ(tr0_perm.node, tr1_perm.node);
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
Clustered nodes: 2236
|
||||
Unclustered nodes: 618
|
||||
Unclustered nodes: 671
|
||||
Number of clusters: 2
|
||||
|
||||
unclustered size 618
|
||||
unclustered size 671
|
||||
AssignAddVariableOp 1
|
||||
Const 120
|
||||
Const 173
|
||||
DivNoNan 1
|
||||
Identity 2
|
||||
Merge 53
|
||||
|
@ -1,10 +1,10 @@
|
||||
Clustered nodes: 1904
|
||||
Unclustered nodes: 455
|
||||
Unclustered nodes: 509
|
||||
Number of clusters: 1
|
||||
|
||||
unclustered size 455
|
||||
unclustered size 509
|
||||
AssignAddVariableOp 2
|
||||
Const 11
|
||||
Const 65
|
||||
DivNoNan 1
|
||||
Identity 1
|
||||
NoOp 1
|
||||
|
@ -1,10 +1,11 @@
|
||||
Clustered nodes: 1962
|
||||
Unclustered nodes: 3974
|
||||
Number of clusters: 29
|
||||
Clustered nodes: 1970
|
||||
Unclustered nodes: 4376
|
||||
Number of clusters: 31
|
||||
|
||||
unclustered size 3974
|
||||
unclustered size 4376
|
||||
Add 17
|
||||
AddN 1
|
||||
All 1
|
||||
ApplyAdam 38
|
||||
Assert 7
|
||||
Assign 47
|
||||
@ -14,25 +15,23 @@ unclustered size 3974
|
||||
Cast 8
|
||||
ConcatOffset 10
|
||||
ConcatV2 2
|
||||
Const 708
|
||||
Const 1111
|
||||
ControlTrigger 5
|
||||
DynamicStitch 1
|
||||
Enter 874
|
||||
Equal 4
|
||||
Exit 69
|
||||
ExpandDims 8
|
||||
FloorDiv 1
|
||||
FloorMod 1
|
||||
GreaterEqual 7
|
||||
Identity 113
|
||||
IsVariableInitialized 1
|
||||
IteratorGetNext 1
|
||||
IteratorV2 1
|
||||
Less 8
|
||||
Less 9
|
||||
LogicalAnd 3
|
||||
LoopCond 8
|
||||
Max 4
|
||||
Maximum 7
|
||||
Maximum 6
|
||||
Merge 145
|
||||
Minimum 5
|
||||
Mul 8
|
||||
@ -63,29 +62,31 @@ unclustered size 3974
|
||||
Unique 2
|
||||
VariableV2 164
|
||||
_Retval 5
|
||||
cluster 0 size 37
|
||||
All 3
|
||||
Cast 2
|
||||
ConcatV2 2
|
||||
Const 4
|
||||
Equal 3
|
||||
ExpandDims 4
|
||||
Fill 2
|
||||
GatherV2 2
|
||||
Identity 1
|
||||
LessEqual 1
|
||||
cluster 0 size 344
|
||||
Abs 40
|
||||
AddN 1
|
||||
Any 41
|
||||
Cast 40
|
||||
Const 3
|
||||
IsInf 1
|
||||
IsNan 40
|
||||
L2Loss 40
|
||||
LogicalOr 1
|
||||
Max 1
|
||||
Max 41
|
||||
Minimum 1
|
||||
Mul 82
|
||||
Pack 3
|
||||
Reciprocal 2
|
||||
Reshape 2
|
||||
ReverseSequence 1
|
||||
Shape 3
|
||||
StridedSlice 3
|
||||
Sqrt 1
|
||||
Sum 1
|
||||
Transpose 3
|
||||
ZerosLike 1
|
||||
cluster 1 size 56
|
||||
cluster 1 size 57
|
||||
BroadcastGradientArgs 1
|
||||
Cast 5
|
||||
ConcatV2 1
|
||||
Const 1
|
||||
Const 2
|
||||
ExpandDims 3
|
||||
Less 1
|
||||
Mean 2
|
||||
@ -107,18 +108,7 @@ cluster 1 size 56
|
||||
Tile 1
|
||||
Transpose 2
|
||||
Unpack 1
|
||||
cluster 2 size 14
|
||||
All 1
|
||||
ConcatV2 2
|
||||
Const 1
|
||||
Equal 1
|
||||
ExpandDims 2
|
||||
Fill 1
|
||||
ReverseSequence 1
|
||||
Shape 1
|
||||
StridedSlice 1
|
||||
Transpose 3
|
||||
cluster 3 size 29
|
||||
cluster 2 size 29
|
||||
Cast 2
|
||||
ConcatV2 2
|
||||
Const 2
|
||||
@ -133,34 +123,38 @@ cluster 3 size 29
|
||||
Range 1
|
||||
Reshape 3
|
||||
Shape 8
|
||||
cluster 4 size 344
|
||||
Abs 40
|
||||
cluster 3 size 9
|
||||
AddN 1
|
||||
Any 41
|
||||
Cast 40
|
||||
Const 3
|
||||
IsInf 1
|
||||
IsNan 40
|
||||
L2Loss 40
|
||||
LogicalOr 1
|
||||
Max 41
|
||||
Minimum 1
|
||||
Mul 82
|
||||
Pack 3
|
||||
Reciprocal 2
|
||||
Reshape 2
|
||||
ReverseSequence 1
|
||||
Sqrt 1
|
||||
MatMul 2
|
||||
Mul 1
|
||||
Reshape 3
|
||||
Sum 1
|
||||
Transpose 3
|
||||
cluster 5 size 3
|
||||
Shape 2
|
||||
Transpose 1
|
||||
cluster 6 size 4
|
||||
All 1
|
||||
Less 1
|
||||
LogicalAnd 1
|
||||
LogicalNot 1
|
||||
cluster 4 size 6
|
||||
ReverseSequence 1
|
||||
Slice 2
|
||||
Transpose 3
|
||||
cluster 5 size 27
|
||||
All 2
|
||||
Cast 1
|
||||
ConcatV2 2
|
||||
Const 5
|
||||
Equal 2
|
||||
ExpandDims 4
|
||||
Fill 2
|
||||
GatherV2 1
|
||||
Identity 1
|
||||
ReverseSequence 1
|
||||
Shape 2
|
||||
StridedSlice 2
|
||||
Transpose 2
|
||||
cluster 6 size 6
|
||||
Cast 1
|
||||
GatherV2 1
|
||||
Shape 1
|
||||
StridedSlice 1
|
||||
Transpose 1
|
||||
ZerosLike 1
|
||||
cluster 7 size 11
|
||||
Cast 1
|
||||
Const 4
|
||||
@ -169,14 +163,19 @@ cluster 7 size 11
|
||||
Mul 2
|
||||
Pow 1
|
||||
Sub 1
|
||||
cluster 10 size 8
|
||||
Add 1
|
||||
All 2
|
||||
Const 2
|
||||
GreaterEqual 1
|
||||
cluster 8 size 4
|
||||
All 1
|
||||
Less 1
|
||||
LogicalAnd 1
|
||||
LogicalNot 1
|
||||
cluster 9 size 7
|
||||
All 1
|
||||
Const 2
|
||||
Equal 1
|
||||
LessEqual 1
|
||||
LogicalOr 1
|
||||
cluster 11 size 226
|
||||
Max 1
|
||||
cluster 10 size 226
|
||||
Add 24
|
||||
BatchMatMulV2 1
|
||||
BiasAdd 8
|
||||
@ -204,14 +203,23 @@ cluster 11 size 226
|
||||
StridedSlice 1
|
||||
Sum 2
|
||||
Tanh 17
|
||||
cluster 12 size 430
|
||||
cluster 11 size 5
|
||||
Add 1
|
||||
All 1
|
||||
Const 1
|
||||
GreaterEqual 1
|
||||
LogicalOr 1
|
||||
cluster 14 size 436
|
||||
Add 22
|
||||
AddN 41
|
||||
BatchMatMulV2 2
|
||||
BiasAddGrad 8
|
||||
ConcatV2 14
|
||||
Const 25
|
||||
Const 28
|
||||
DynamicStitch 1
|
||||
FloorDiv 1
|
||||
MatMul 20
|
||||
Maximum 1
|
||||
Mul 74
|
||||
NoOp 13
|
||||
Reshape 86
|
||||
@ -224,35 +232,6 @@ cluster 12 size 430
|
||||
TanhGrad 17
|
||||
Tile 2
|
||||
ZerosLike 1
|
||||
cluster 13 size 20
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
ConcatV2 1
|
||||
Const 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 3
|
||||
Select 3
|
||||
Shape 1
|
||||
Sigmoid 3
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 14 size 52
|
||||
Add 2
|
||||
AddN 4
|
||||
BiasAddGrad 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
Const 3
|
||||
MatMul 2
|
||||
Mul 6
|
||||
NoOp 2
|
||||
Reshape 9
|
||||
Select 5
|
||||
SigmoidGrad 3
|
||||
Slice 2
|
||||
Sum 9
|
||||
TanhGrad 2
|
||||
cluster 15 size 20
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
@ -266,11 +245,7 @@ cluster 15 size 20
|
||||
Sigmoid 3
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 16 size 6
|
||||
ReverseSequence 1
|
||||
Slice 2
|
||||
Transpose 3
|
||||
cluster 17 size 52
|
||||
cluster 16 size 52
|
||||
Add 2
|
||||
AddN 4
|
||||
BiasAddGrad 1
|
||||
@ -286,7 +261,47 @@ cluster 17 size 52
|
||||
Slice 2
|
||||
Sum 9
|
||||
TanhGrad 2
|
||||
cluster 19 size 25
|
||||
cluster 17 size 15
|
||||
All 1
|
||||
ConcatV2 2
|
||||
Const 2
|
||||
Equal 1
|
||||
ExpandDims 2
|
||||
Fill 1
|
||||
ReverseSequence 1
|
||||
Shape 1
|
||||
StridedSlice 1
|
||||
Transpose 3
|
||||
cluster 18 size 20
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
ConcatV2 1
|
||||
Const 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 3
|
||||
Select 3
|
||||
Shape 1
|
||||
Sigmoid 3
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 19 size 52
|
||||
Add 2
|
||||
AddN 4
|
||||
BiasAddGrad 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
Const 3
|
||||
MatMul 2
|
||||
Mul 6
|
||||
NoOp 2
|
||||
Reshape 9
|
||||
Select 5
|
||||
SigmoidGrad 3
|
||||
Slice 2
|
||||
Sum 9
|
||||
TanhGrad 2
|
||||
cluster 21 size 25
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
@ -300,7 +315,20 @@ cluster 19 size 25
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 20 size 363
|
||||
cluster 22 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 3
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 23 size 363
|
||||
Add 12
|
||||
AddN 28
|
||||
BiasAddGrad 6
|
||||
@ -316,97 +344,80 @@ cluster 20 size 363
|
||||
Slice 12
|
||||
Sum 76
|
||||
TanhGrad 12
|
||||
cluster 21 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 22 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 23 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 24 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 25 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 3
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 26 size 9
|
||||
AddN 1
|
||||
MatMul 2
|
||||
Mul 1
|
||||
Reshape 3
|
||||
Sum 1
|
||||
cluster 24 size 3
|
||||
Shape 2
|
||||
Transpose 1
|
||||
cluster 27 size 9
|
||||
cluster 25 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 26 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 27 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 28 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 29 size 9
|
||||
Add 1
|
||||
Mul 2
|
||||
RealDiv 2
|
||||
Sqrt 2
|
||||
Sub 2
|
||||
cluster 28 size 9
|
||||
cluster 30 size 9
|
||||
Add 1
|
||||
Mul 2
|
||||
RealDiv 2
|
||||
Sqrt 2
|
||||
Sub 2
|
||||
cluster 29 size 4
|
||||
cluster 31 size 4
|
||||
Mul 3
|
||||
UnsortedSegmentSum 1
|
||||
cluster 30 size 4
|
||||
cluster 32 size 4
|
||||
Mul 3
|
||||
UnsortedSegmentSum 1
|
||||
cluster 31 size 116
|
||||
cluster 33 size 116
|
||||
Cast 38
|
||||
Const 2
|
||||
Maximum 38
|
||||
|
Loading…
Reference in New Issue
Block a user