Update VirtualScheduler constructor.
PiperOrigin-RevId: 241662817
This commit is contained in:
parent
483e187a7b
commit
1597edbeed
@ -123,7 +123,22 @@ AnalyticalCostEstimator::AnalyticalCostEstimator(
|
||||
use_aggressive_shape_inference_(use_aggressive_shape_inference) {
|
||||
scheduler_ = absl::make_unique<VirtualScheduler>(
|
||||
use_static_shapes_, use_aggressive_shape_inference_, cluster,
|
||||
node_manager_.get());
|
||||
node_manager_.get(),
|
||||
absl::make_unique<VirtualPlacer>(cluster->GetDevices()));
|
||||
}
|
||||
|
||||
AnalyticalCostEstimator::AnalyticalCostEstimator(
|
||||
Cluster* cluster, std::unique_ptr<OpLevelCostEstimator> node_estimator,
|
||||
std::unique_ptr<ReadyNodeManager> node_manager,
|
||||
std::unique_ptr<VirtualPlacer> placer, bool use_static_shapes,
|
||||
bool use_aggressive_shape_inference)
|
||||
: node_estimator_(std::move(node_estimator)),
|
||||
node_manager_(std::move(node_manager)),
|
||||
use_static_shapes_(use_static_shapes),
|
||||
use_aggressive_shape_inference_(use_aggressive_shape_inference) {
|
||||
scheduler_ = absl::make_unique<VirtualScheduler>(
|
||||
use_static_shapes_, use_aggressive_shape_inference_, cluster,
|
||||
node_manager_.get(), std::move(placer));
|
||||
}
|
||||
|
||||
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
|
||||
|
@ -47,6 +47,12 @@ class AnalyticalCostEstimator : public CostEstimator {
|
||||
std::unique_ptr<ReadyNodeManager> node_manager,
|
||||
bool use_static_shapes,
|
||||
bool use_aggressive_shape_inference);
|
||||
AnalyticalCostEstimator(Cluster* cluster,
|
||||
std::unique_ptr<OpLevelCostEstimator> node_estimator,
|
||||
std::unique_ptr<ReadyNodeManager> node_manager,
|
||||
std::unique_ptr<VirtualPlacer> placer,
|
||||
bool use_static_shapes,
|
||||
bool use_aggressive_shape_inference);
|
||||
~AnalyticalCostEstimator() override {}
|
||||
|
||||
// Initializes the estimator for the specified grappler item.
|
||||
|
@ -259,13 +259,15 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
|
||||
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference,
|
||||
Cluster* cluster,
|
||||
ReadyNodeManager* ready_nodes)
|
||||
ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<VirtualPlacer> placer)
|
||||
: ready_nodes_(ready_nodes),
|
||||
graph_costs_(Costs::ZeroCosts()),
|
||||
cluster_(cluster),
|
||||
use_static_shapes_(use_static_shapes),
|
||||
use_aggressive_shape_inference_(use_aggressive_shape_inference),
|
||||
placer_(cluster->GetDevices()) {
|
||||
placer_(std::move(placer)) {
|
||||
DCHECK(placer_); // check if the pointer is valid.
|
||||
graph_costs_.num_ops_total = 0;
|
||||
initialized_ = false;
|
||||
track_mem_usage_snapshot_ = VLOG_IS_ON(1);
|
||||
@ -524,13 +526,13 @@ bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
|
||||
}
|
||||
|
||||
string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
||||
return placer_.get_canonical_device_name(*node);
|
||||
return placer_->get_canonical_device_name(*node);
|
||||
}
|
||||
|
||||
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
|
||||
// Replace the ":" characters that may be present in the device name with "_".
|
||||
// This makes it possible to then use the resulting string in a node name.
|
||||
return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":",
|
||||
return str_util::StringReplace(placer_->get_canonical_device_name(*node), ":",
|
||||
"_", true);
|
||||
}
|
||||
|
||||
@ -620,7 +622,7 @@ OpContext VirtualScheduler::GetCurrNode() const {
|
||||
|
||||
// Get the device from the placer.
|
||||
DeviceProperties device;
|
||||
device = placer_.get_device(*node);
|
||||
device = placer_->get_device(*node);
|
||||
|
||||
// Special case for _Send op.
|
||||
if (IsSend(*node)) {
|
||||
|
@ -263,7 +263,9 @@ class VirtualScheduler {
|
||||
// Does not take ownership of cluster or ready_nodes.
|
||||
VirtualScheduler(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||
ReadyNodeManager* ready_nodes);
|
||||
ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<VirtualPlacer> placer);
|
||||
|
||||
// Initializes the scheduler for the specific grappler item.
|
||||
// Should be called immediately after the c'tor or when the scheduler will be
|
||||
// reused for a new grappler item. All internal states of the scheduler
|
||||
@ -356,7 +358,7 @@ class VirtualScheduler {
|
||||
bool track_mem_usage_snapshot_;
|
||||
const bool use_aggressive_shape_inference_;
|
||||
|
||||
VirtualPlacer placer_; // owned.
|
||||
std::unique_ptr<VirtualPlacer> placer_;
|
||||
};
|
||||
|
||||
} // namespace grappler
|
||||
|
@ -33,8 +33,10 @@ class TestVirtualScheduler : public VirtualScheduler {
|
||||
TestVirtualScheduler(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference,
|
||||
Cluster* cluster)
|
||||
: VirtualScheduler(use_static_shapes, use_aggressive_shape_inference,
|
||||
cluster, &ready_node_manager_) {
|
||||
: VirtualScheduler(
|
||||
use_static_shapes, use_aggressive_shape_inference, cluster,
|
||||
&ready_node_manager_,
|
||||
absl::make_unique<VirtualPlacer>(cluster->GetDevices())) {
|
||||
enable_mem_usage_tracking();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user