Refactor CalculateOutputSize() from VirtualScheduler protected member function to utils; Refactor EstimateSize() from memory_optimizer.cc to utils; some small changes for readability improvement
PiperOrigin-RevId: 216307257
This commit is contained in:
parent
d1f0494b89
commit
e27ee15fa4
@ -236,6 +236,7 @@ tf_cc_test(
|
||||
name = "virtual_scheduler_test",
|
||||
srcs = ["virtual_scheduler_test.cc"],
|
||||
deps = [
|
||||
":utils",
|
||||
":virtual_placer",
|
||||
":virtual_scheduler",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
|
||||
@ -74,7 +74,8 @@ static std::vector<TensorProto> ExtractTensors(const AttrValue& attr_value) {
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {}
|
||||
default: {
|
||||
}
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
@ -201,6 +202,43 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
|
||||
return inputs;
|
||||
}
|
||||
|
||||
int64 CalculateTensorSize(const OpInfo::TensorProperties& prop) {
|
||||
int64 size = DataTypeSize(BaseType(prop.dtype()));
|
||||
TensorShapeProto shape = prop.shape();
|
||||
|
||||
// Can't infer the size if the rank is unknown. It has to be at least a
|
||||
// scalar though.
|
||||
if (shape.unknown_rank()) {
|
||||
LOG(WARNING) << "CalculateTensorSize() -- unknown rank";
|
||||
return size;
|
||||
}
|
||||
|
||||
// If one of the dimensions is unknown statically, assume it's at least one.
|
||||
for (int i = 0; i < shape.dim_size(); ++i) {
|
||||
if (shape.dim(i).size() < 0) {
|
||||
shape.mutable_dim(i)->set_size(1);
|
||||
LOG(WARNING) << "CalculateTensorSize() -- unknown dim: " << i;
|
||||
}
|
||||
}
|
||||
|
||||
int64 num_elems = TensorShape(shape).num_elements();
|
||||
return num_elems * size;
|
||||
}
|
||||
|
||||
int64 CalculateOutputSize(
|
||||
const std::vector<OpInfo::TensorProperties>& output_properties,
|
||||
const int port_num) {
|
||||
if (port_num < 0) return 4; // 4B for control dependency.
|
||||
|
||||
if (port_num >= output_properties.size()) {
|
||||
LOG(ERROR) << "CalculateOutputSize() -- port_num: " << port_num
|
||||
<< " >= output_properties.size(): " << output_properties.size();
|
||||
return 0;
|
||||
}
|
||||
|
||||
return CalculateTensorSize(output_properties[port_num]);
|
||||
}
|
||||
|
||||
DeviceProperties GetDeviceInfo(const string& device_str) {
|
||||
DeviceProperties unknown;
|
||||
unknown.set_type("UNKNOWN");
|
||||
|
||||
@ -43,6 +43,17 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
|
||||
const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost,
|
||||
const std::unordered_map<string, const NodeDef*>& name_to_node);
|
||||
|
||||
// Returns the size of tensor (unit: bytes). For tensor shape with unknown rank,
|
||||
// it assumes the tensor to be scalar. For any unknown dimension, it assumes
|
||||
// size one.
|
||||
int64 CalculateTensorSize(const OpInfo::TensorProperties& prop);
|
||||
|
||||
// Returns the size of output at port_num (unit: bytes). A special case is
|
||||
// port_num -1, which is for control dependency and assumed to be 4 bytes.
|
||||
int64 CalculateOutputSize(
|
||||
const std::vector<OpInfo::TensorProperties>& output_properties,
|
||||
int port_num);
|
||||
|
||||
// Returns the DeviceProperties of the device on which 'node' runs.
|
||||
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node);
|
||||
DeviceProperties GetDeviceInfo(const string& device_str);
|
||||
|
||||
@ -26,36 +26,42 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class UtilsTest : public ::testing::Test {
|
||||
public:
|
||||
void CreateConstOp(const string& name, std::initializer_list<int64> dims,
|
||||
NodeDef* node) {
|
||||
Tensor tensor(DT_FLOAT, TensorShape(dims));
|
||||
for (int64 i = 0; i < tensor.NumElements(); ++i) {
|
||||
tensor.flat<float>()(i) = i / 10.0f;
|
||||
}
|
||||
TF_CHECK_OK(NodeDefBuilder(name, "Const")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("value", tensor)
|
||||
.Finalize(node));
|
||||
}
|
||||
namespace {
|
||||
|
||||
void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes,
|
||||
NodeDef* node) {
|
||||
TensorShape shape;
|
||||
shape.AddDim(sizes.size());
|
||||
Tensor tensor(DT_INT32, shape);
|
||||
for (int64 i = 0; i < tensor.NumElements(); ++i) {
|
||||
tensor.flat<int32>()(i) = sizes[i];
|
||||
}
|
||||
TF_CHECK_OK(NodeDefBuilder(name, "Const")
|
||||
.Attr("dtype", DT_INT32)
|
||||
.Attr("value", tensor)
|
||||
.Finalize(node));
|
||||
}
|
||||
};
|
||||
void CreateConstOp(const string& name, std::initializer_list<int64> dims,
|
||||
NodeDef* node) {
|
||||
Tensor tensor(DT_FLOAT, TensorShape(dims));
|
||||
for (int64 i = 0; i < tensor.NumElements(); ++i)
|
||||
tensor.flat<float>()(i) = i / 10.0f;
|
||||
TF_CHECK_OK(NodeDefBuilder(name, "Const")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("value", tensor)
|
||||
.Finalize(node));
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, ConvOpInfo) {
|
||||
void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes,
|
||||
NodeDef* node) {
|
||||
TensorShape shape;
|
||||
shape.AddDim(sizes.size());
|
||||
Tensor tensor(DT_INT32, shape);
|
||||
for (int64 i = 0; i < tensor.NumElements(); ++i)
|
||||
tensor.flat<int32>()(i) = sizes[i];
|
||||
TF_CHECK_OK(NodeDefBuilder(name, "Const")
|
||||
.Attr("dtype", DT_INT32)
|
||||
.Attr("value", tensor)
|
||||
.Finalize(node));
|
||||
}
|
||||
|
||||
// Helper method for converting shapes vector to TensorProperty.
|
||||
OpInfo::TensorProperties ShapeToTensorProperty(const std::vector<int>& shapes,
|
||||
const DataType& data_type) {
|
||||
OpInfo::TensorProperties prop;
|
||||
prop.set_dtype(data_type);
|
||||
for (int shape : shapes) prop.mutable_shape()->add_dim()->set_size(shape);
|
||||
return prop;
|
||||
}
|
||||
|
||||
TEST(UtilsTest, ConvOpInfo) {
|
||||
int batch = 32;
|
||||
int rows = 7;
|
||||
int cols = 9;
|
||||
@ -146,7 +152,7 @@ TEST_F(UtilsTest, ConvOpInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, TestSkipControlInput) {
|
||||
TEST(UtilsTest, TestSkipControlInput) {
|
||||
GraphDef graph;
|
||||
TF_CHECK_OK(NodeDefBuilder("constant", "Const")
|
||||
.Attr("dtype", DT_INT32)
|
||||
@ -172,6 +178,52 @@ TEST_F(UtilsTest, TestSkipControlInput) {
|
||||
EXPECT_TRUE(node_found);
|
||||
}
|
||||
|
||||
TEST(UtilsTest, CalculateTensorSize) {
|
||||
// Test normal usage.
|
||||
EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1,
|
||||
CalculateTensorSize(ShapeToTensorProperty({1}, DT_FLOAT)));
|
||||
EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4,
|
||||
CalculateTensorSize(ShapeToTensorProperty({4, 4}, DT_FLOAT)));
|
||||
EXPECT_EQ(DataTypeSize(DT_HALF) * 10 * 10 * 10,
|
||||
CalculateTensorSize(ShapeToTensorProperty({10, 10, 10}, DT_HALF)));
|
||||
EXPECT_EQ(
|
||||
DataTypeSize(DT_FLOAT) * 100 * 7 * 8 * 99,
|
||||
CalculateTensorSize(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)));
|
||||
|
||||
// Test unknown rank: assumes the tensor to be a scalar.
|
||||
OpInfo::TensorProperties t = ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT);
|
||||
t.mutable_shape()->set_unknown_rank(true);
|
||||
EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1, CalculateTensorSize(t));
|
||||
|
||||
// Test unknown shape: assumes unknown shape (-1) to have size 1.
|
||||
EXPECT_EQ(
|
||||
DataTypeSize(DT_FLOAT) * 1 * 7 * 8 * 99,
|
||||
CalculateTensorSize(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)));
|
||||
EXPECT_EQ(
|
||||
DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99,
|
||||
CalculateTensorSize(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)));
|
||||
}
|
||||
|
||||
TEST(UtilsTest, CalculateOutputSize) {
|
||||
// Create a set of tensor properties.
|
||||
std::vector<OpInfo::TensorProperties> output = {
|
||||
ShapeToTensorProperty({4, 4}, DT_FLOAT), // 0
|
||||
ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT) // 1
|
||||
};
|
||||
|
||||
// Test valid outputs.
|
||||
EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4, CalculateOutputSize(output, 0));
|
||||
EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99,
|
||||
CalculateOutputSize(output, 1));
|
||||
|
||||
// port_num -1 is for control dependency: hard coded 4B.
|
||||
EXPECT_EQ(4, CalculateOutputSize(output, -1));
|
||||
|
||||
// Invalid port_num (though it may be an error) shall yield zero
|
||||
// output size.
|
||||
EXPECT_EQ(0, CalculateOutputSize(output, 2));
|
||||
}
|
||||
|
||||
// Class for testing TensorSizeHistogram.
|
||||
class TestTensorSizeHistogram : public TensorSizeHistogram {
|
||||
public:
|
||||
@ -285,5 +337,7 @@ TEST(DeviceClassTest, GetDeviceClassForNonChannelDevice) {
|
||||
EXPECT_EQ("//GPU", GetDeviceClassForNonChannelDevice("/device:GPU:7"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
@ -473,6 +473,7 @@ Status VirtualScheduler::Init() {
|
||||
VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
|
||||
<< str_util::Join(feed_nodes, ",");
|
||||
}
|
||||
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -695,38 +696,6 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
int64 VirtualScheduler::CalculateOutputSize(
|
||||
const std::vector<OpInfo::TensorProperties>& output_properties,
|
||||
const int port_num) const {
|
||||
if (port_num < 0) {
|
||||
return 4; // 4B for control dependency.
|
||||
}
|
||||
|
||||
if (port_num >= output_properties.size()) {
|
||||
VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
|
||||
<< "port_num: " << port_num
|
||||
<< " >= output_properties.size(): " << output_properties.size();
|
||||
return 0;
|
||||
}
|
||||
|
||||
const auto& output = output_properties[port_num];
|
||||
int64 output_size = DataTypeSize(BaseType(output.dtype()));
|
||||
|
||||
for (const auto& dim : output.shape().dim()) {
|
||||
auto dim_size = dim.size();
|
||||
if (dim_size < 0) {
|
||||
// Zero output size if there's any unknown dim.
|
||||
output_size = 0;
|
||||
VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
|
||||
<< "unknown dim: " << output_size;
|
||||
break;
|
||||
}
|
||||
output_size *= dim_size;
|
||||
}
|
||||
|
||||
return output_size;
|
||||
}
|
||||
|
||||
Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
|
||||
std::map<string, Costs>* op_cost) {
|
||||
auto it = op_cost->find(op_name);
|
||||
@ -744,7 +713,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||
const string& op_name = node->op();
|
||||
|
||||
// Also keep track of op counts and times per op (with their shapes).
|
||||
auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
|
||||
op_cost = CombineCosts(op_cost, node_costs);
|
||||
|
||||
// Also keep track of op counts and costs per op (with their shapes).
|
||||
OpContext op_context = GetCurrNode();
|
||||
string node_description = GetOpDescription(op_context.op_info);
|
||||
op_counts_[node_description] += 1;
|
||||
@ -752,9 +724,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||
std::make_pair(node_costs.execution_time.asMicroSeconds().count(),
|
||||
!node_costs.inaccurate);
|
||||
|
||||
auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
|
||||
op_cost = CombineCosts(op_cost, node_costs);
|
||||
|
||||
// Update node and device states.
|
||||
auto& node_state = node_map_[node];
|
||||
auto& device = device_[node_state.device_name];
|
||||
@ -795,7 +764,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||
<< ", scheduled: " << node_state.time_scheduled.count()
|
||||
<< ", finished: " << node_state.time_finished.count();
|
||||
|
||||
// Increment num_inputs_ready of the output nodes
|
||||
// Increment num_inputs_ready of the output nodes and maybe add to ready nodes
|
||||
for (const auto& port_num_output_pair : node_state.outputs) {
|
||||
for (auto* output_node : port_num_output_pair.second) {
|
||||
auto& output_state = node_map_[output_node];
|
||||
@ -812,7 +781,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||
}
|
||||
}
|
||||
|
||||
// Increment num_outputs_executed of the input nodes.
|
||||
// Increment num_outputs_executed of the input nodes and maybe update memory.
|
||||
for (const auto& input_port : node_state.inputs) {
|
||||
auto* input = input_port.first;
|
||||
auto port = input_port.second;
|
||||
@ -841,7 +810,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the current node; assume FIFO.
|
||||
ready_nodes_->RemoveCurrNode();
|
||||
|
||||
return !ready_nodes_->Empty();
|
||||
@ -1007,7 +975,7 @@ Costs VirtualScheduler::Summary(RunMetadata* metadata) {
|
||||
return Summary();
|
||||
}
|
||||
|
||||
// Fill RunMetadata.
|
||||
// Fill RunMetadata's step_stats and partition_graphs fields.
|
||||
StepStats* stepstats = metadata->mutable_step_stats();
|
||||
for (const auto& device : device_) {
|
||||
GraphDef* device_partition_graph = metadata->add_partition_graphs();
|
||||
|
||||
@ -107,10 +107,10 @@ struct DeviceState {
|
||||
mem_usage_snapshot_at_peak;
|
||||
|
||||
Costs device_costs;
|
||||
std::map<string, Costs> op_to_cost; // Per-op cost.
|
||||
std::map<string, int64> op_to_memory; // Per-op memory usage at peak usage.
|
||||
int64 memory_usage;
|
||||
int64 max_memory_usage;
|
||||
std::map<string, Costs> op_to_cost; // Per-op cost.
|
||||
|
||||
int64 memory_usage; // Current temporary memory usage
|
||||
int64 max_memory_usage; // Max temporary memory usage
|
||||
|
||||
DeviceState() {
|
||||
device_costs = Costs::ZeroCosts();
|
||||
@ -283,13 +283,6 @@ class VirtualScheduler {
|
||||
return &node_map_;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Returns the size of output at port_num (unit: bytes). A special case is
|
||||
// port_num -1, which is for control dependency and assumed to be 4 bytes.
|
||||
int64 CalculateOutputSize(
|
||||
const std::vector<OpInfo::TensorProperties>& output_properties,
|
||||
const int port_num) const;
|
||||
|
||||
private:
|
||||
// Constants.
|
||||
const string kAttrInputSrc = "input_source_";
|
||||
@ -321,8 +314,11 @@ class VirtualScheduler {
|
||||
std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
|
||||
|
||||
// Stats:
|
||||
std::map<string, int> op_counts_; // Op counts with key with input shape.
|
||||
// Individual op costs (with input shapes).
|
||||
// Op counts with key with input shape.
|
||||
// Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]"
|
||||
std::map<string, int> op_counts_;
|
||||
// Individual op costs with key with input shape.
|
||||
// Integer field for execution time in micro seconds.
|
||||
// Boolean field for whether the cost is accurate.
|
||||
std::map<string, std::pair<int, bool>> op_costs_;
|
||||
|
||||
|
||||
@ -19,12 +19,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/costs/virtual_placer.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Class for testing virtual scheduler.
|
||||
class TestVirtualScheduler : public VirtualScheduler {
|
||||
public:
|
||||
@ -33,7 +35,6 @@ class TestVirtualScheduler : public VirtualScheduler {
|
||||
: VirtualScheduler(grappler_item, use_static_shapes, cluster,
|
||||
&ready_node_manager_) {}
|
||||
|
||||
FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
|
||||
FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
|
||||
FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
|
||||
FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
|
||||
@ -1034,17 +1035,6 @@ versions {
|
||||
}
|
||||
}
|
||||
|
||||
// Helper method for converting shape vector to TensorProperty.
|
||||
OpInfo::TensorProperties ShapeToTensorProperty(
|
||||
const std::vector<int> shape, const DataType& data_type) const {
|
||||
OpInfo::TensorProperties tensor_property;
|
||||
tensor_property.set_dtype(data_type);
|
||||
for (const auto& x : shape) {
|
||||
tensor_property.mutable_shape()->add_dim()->set_size(x);
|
||||
}
|
||||
return tensor_property;
|
||||
}
|
||||
|
||||
// SetUp() inits cluster_ and placer_.
|
||||
std::unique_ptr<VirtualCluster> cluster_;
|
||||
std::unique_ptr<VirtualPlacer> placer_;
|
||||
@ -1729,38 +1719,6 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
|
||||
EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
|
||||
}
|
||||
|
||||
TEST_F(VirtualSchedulerTest, CalculateOutputSize) {
|
||||
// Init.
|
||||
CreateGrapplerItemWithAddN();
|
||||
InitScheduler();
|
||||
|
||||
// Create a set of tensor properties.
|
||||
std::vector<OpInfo::TensorProperties> output;
|
||||
output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0
|
||||
output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1
|
||||
output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2
|
||||
output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3
|
||||
output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4
|
||||
output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)); // 4
|
||||
|
||||
// port_num -1 is for control dependency: hard coded 4B.
|
||||
EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1));
|
||||
|
||||
// Test valid outputs.
|
||||
EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0));
|
||||
EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1));
|
||||
EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2));
|
||||
EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3));
|
||||
|
||||
// Any unknown shape (-1) shall yield zero output size.
|
||||
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4));
|
||||
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5));
|
||||
|
||||
// Invalid port_num (though it may be an error) shall yield zero
|
||||
// output size.
|
||||
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6));
|
||||
}
|
||||
|
||||
TEST_F(VirtualSchedulerTest, MemoryUsage) {
|
||||
// Init.
|
||||
CreateGrapplerItemWithAddN();
|
||||
@ -2041,7 +1999,7 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
|
||||
for (const auto& output_property : output_properties_) {
|
||||
output_properties.push_back(output_property);
|
||||
}
|
||||
return scheduler_->CalculateOutputSize(output_properties, 0);
|
||||
return CalculateOutputSize(output_properties, 0);
|
||||
};
|
||||
|
||||
// Validate transfer size.
|
||||
|
||||
@ -423,6 +423,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler/clusters:virtual_cluster",
|
||||
"//tensorflow/core/grappler/costs:graph_memory",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/costs:utils",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
"//tensorflow/core/grappler/utils:traversal",
|
||||
],
|
||||
|
||||
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_memory.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
@ -43,6 +44,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
// Prefix added to nodes which are recomputed.
|
||||
const char* kRecomputedNodePrefix = "Recomputed";
|
||||
const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
|
||||
@ -744,25 +747,6 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static int64 EstimateSize(const OpInfo::TensorProperties& t) {
|
||||
DataType dtype = t.dtype();
|
||||
int64 size = DataTypeSize(dtype);
|
||||
TensorShapeProto shape = t.shape();
|
||||
if (shape.unknown_rank()) {
|
||||
// Can't infer the size if the rank is unknown. It has to be at least a
|
||||
// scalar though.
|
||||
return size;
|
||||
}
|
||||
// If one of the dimensions is unknown statically, assume it's at least one.
|
||||
for (int i = 0; i < shape.dim_size(); ++i) {
|
||||
if (shape.dim(i).size() < 0) {
|
||||
shape.mutable_dim(i)->set_size(1);
|
||||
}
|
||||
}
|
||||
int64 num_elems = TensorShape(shape).num_elements();
|
||||
return num_elems * size;
|
||||
}
|
||||
|
||||
struct SwapInfo {
|
||||
std::vector<int> inputs_to_swap;
|
||||
Costs::NanoSeconds time_to_swap = 0;
|
||||
@ -1149,7 +1133,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
|
||||
int64 bytes_to_swap = 0;
|
||||
for (int64 input_id : swap_info.inputs_to_swap) {
|
||||
const OpInfo::TensorProperties& t = props[input_id];
|
||||
bytes_to_swap += EstimateSize(t);
|
||||
bytes_to_swap += CalculateTensorSize(t);
|
||||
}
|
||||
// Let's assume we're going to swap over PCIe running at 16 GBps.
|
||||
swap_info.time_to_swap = bytes_to_swap / 16;
|
||||
@ -1299,6 +1283,8 @@ Status RelaxAllocatorConstraints(GraphDef* optimized_graph) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) {
|
||||
*optimized_graph = item.graph;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user