Update VirtualScheduler constructor.
PiperOrigin-RevId: 241662817
This commit is contained in:
parent
483e187a7b
commit
1597edbeed
@ -123,7 +123,22 @@ AnalyticalCostEstimator::AnalyticalCostEstimator(
|
|||||||
use_aggressive_shape_inference_(use_aggressive_shape_inference) {
|
use_aggressive_shape_inference_(use_aggressive_shape_inference) {
|
||||||
scheduler_ = absl::make_unique<VirtualScheduler>(
|
scheduler_ = absl::make_unique<VirtualScheduler>(
|
||||||
use_static_shapes_, use_aggressive_shape_inference_, cluster,
|
use_static_shapes_, use_aggressive_shape_inference_, cluster,
|
||||||
node_manager_.get());
|
node_manager_.get(),
|
||||||
|
absl::make_unique<VirtualPlacer>(cluster->GetDevices()));
|
||||||
|
}
|
||||||
|
|
||||||
|
AnalyticalCostEstimator::AnalyticalCostEstimator(
|
||||||
|
Cluster* cluster, std::unique_ptr<OpLevelCostEstimator> node_estimator,
|
||||||
|
std::unique_ptr<ReadyNodeManager> node_manager,
|
||||||
|
std::unique_ptr<VirtualPlacer> placer, bool use_static_shapes,
|
||||||
|
bool use_aggressive_shape_inference)
|
||||||
|
: node_estimator_(std::move(node_estimator)),
|
||||||
|
node_manager_(std::move(node_manager)),
|
||||||
|
use_static_shapes_(use_static_shapes),
|
||||||
|
use_aggressive_shape_inference_(use_aggressive_shape_inference) {
|
||||||
|
scheduler_ = absl::make_unique<VirtualScheduler>(
|
||||||
|
use_static_shapes_, use_aggressive_shape_inference_, cluster,
|
||||||
|
node_manager_.get(), std::move(placer));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
|
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
|
||||||
|
|||||||
@ -47,6 +47,12 @@ class AnalyticalCostEstimator : public CostEstimator {
|
|||||||
std::unique_ptr<ReadyNodeManager> node_manager,
|
std::unique_ptr<ReadyNodeManager> node_manager,
|
||||||
bool use_static_shapes,
|
bool use_static_shapes,
|
||||||
bool use_aggressive_shape_inference);
|
bool use_aggressive_shape_inference);
|
||||||
|
AnalyticalCostEstimator(Cluster* cluster,
|
||||||
|
std::unique_ptr<OpLevelCostEstimator> node_estimator,
|
||||||
|
std::unique_ptr<ReadyNodeManager> node_manager,
|
||||||
|
std::unique_ptr<VirtualPlacer> placer,
|
||||||
|
bool use_static_shapes,
|
||||||
|
bool use_aggressive_shape_inference);
|
||||||
~AnalyticalCostEstimator() override {}
|
~AnalyticalCostEstimator() override {}
|
||||||
|
|
||||||
// Initializes the estimator for the specified grappler item.
|
// Initializes the estimator for the specified grappler item.
|
||||||
|
|||||||
@ -259,13 +259,15 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
|
|||||||
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||||
const bool use_aggressive_shape_inference,
|
const bool use_aggressive_shape_inference,
|
||||||
Cluster* cluster,
|
Cluster* cluster,
|
||||||
ReadyNodeManager* ready_nodes)
|
ReadyNodeManager* ready_nodes,
|
||||||
|
std::unique_ptr<VirtualPlacer> placer)
|
||||||
: ready_nodes_(ready_nodes),
|
: ready_nodes_(ready_nodes),
|
||||||
graph_costs_(Costs::ZeroCosts()),
|
graph_costs_(Costs::ZeroCosts()),
|
||||||
cluster_(cluster),
|
cluster_(cluster),
|
||||||
use_static_shapes_(use_static_shapes),
|
use_static_shapes_(use_static_shapes),
|
||||||
use_aggressive_shape_inference_(use_aggressive_shape_inference),
|
use_aggressive_shape_inference_(use_aggressive_shape_inference),
|
||||||
placer_(cluster->GetDevices()) {
|
placer_(std::move(placer)) {
|
||||||
|
DCHECK(placer_); // check if the pointer is valid.
|
||||||
graph_costs_.num_ops_total = 0;
|
graph_costs_.num_ops_total = 0;
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
track_mem_usage_snapshot_ = VLOG_IS_ON(1);
|
track_mem_usage_snapshot_ = VLOG_IS_ON(1);
|
||||||
@ -524,13 +526,13 @@ bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
||||||
return placer_.get_canonical_device_name(*node);
|
return placer_->get_canonical_device_name(*node);
|
||||||
}
|
}
|
||||||
|
|
||||||
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
|
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
|
||||||
// Replace the ":" characters that may be present in the device name with "_".
|
// Replace the ":" characters that may be present in the device name with "_".
|
||||||
// This makes it possible to then use the resulting string in a node name.
|
// This makes it possible to then use the resulting string in a node name.
|
||||||
return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":",
|
return str_util::StringReplace(placer_->get_canonical_device_name(*node), ":",
|
||||||
"_", true);
|
"_", true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -620,7 +622,7 @@ OpContext VirtualScheduler::GetCurrNode() const {
|
|||||||
|
|
||||||
// Get the device from the placer.
|
// Get the device from the placer.
|
||||||
DeviceProperties device;
|
DeviceProperties device;
|
||||||
device = placer_.get_device(*node);
|
device = placer_->get_device(*node);
|
||||||
|
|
||||||
// Special case for _Send op.
|
// Special case for _Send op.
|
||||||
if (IsSend(*node)) {
|
if (IsSend(*node)) {
|
||||||
|
|||||||
@ -263,7 +263,9 @@ class VirtualScheduler {
|
|||||||
// Does not take ownership of cluster or ready_nodes.
|
// Does not take ownership of cluster or ready_nodes.
|
||||||
VirtualScheduler(const bool use_static_shapes,
|
VirtualScheduler(const bool use_static_shapes,
|
||||||
const bool use_aggressive_shape_inference, Cluster* cluster,
|
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||||
ReadyNodeManager* ready_nodes);
|
ReadyNodeManager* ready_nodes,
|
||||||
|
std::unique_ptr<VirtualPlacer> placer);
|
||||||
|
|
||||||
// Initializes the scheduler for the specific grappler item.
|
// Initializes the scheduler for the specific grappler item.
|
||||||
// Should be called immediately after the c'tor or when the scheduler will be
|
// Should be called immediately after the c'tor or when the scheduler will be
|
||||||
// reused for a new grappler item. All internal states of the scheduler
|
// reused for a new grappler item. All internal states of the scheduler
|
||||||
@ -356,7 +358,7 @@ class VirtualScheduler {
|
|||||||
bool track_mem_usage_snapshot_;
|
bool track_mem_usage_snapshot_;
|
||||||
const bool use_aggressive_shape_inference_;
|
const bool use_aggressive_shape_inference_;
|
||||||
|
|
||||||
VirtualPlacer placer_; // owned.
|
std::unique_ptr<VirtualPlacer> placer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
|
|||||||
@ -33,8 +33,10 @@ class TestVirtualScheduler : public VirtualScheduler {
|
|||||||
TestVirtualScheduler(const bool use_static_shapes,
|
TestVirtualScheduler(const bool use_static_shapes,
|
||||||
const bool use_aggressive_shape_inference,
|
const bool use_aggressive_shape_inference,
|
||||||
Cluster* cluster)
|
Cluster* cluster)
|
||||||
: VirtualScheduler(use_static_shapes, use_aggressive_shape_inference,
|
: VirtualScheduler(
|
||||||
cluster, &ready_node_manager_) {
|
use_static_shapes, use_aggressive_shape_inference, cluster,
|
||||||
|
&ready_node_manager_,
|
||||||
|
absl::make_unique<VirtualPlacer>(cluster->GetDevices())) {
|
||||||
enable_mem_usage_tracking();
|
enable_mem_usage_tracking();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user