Update VirtualScheduler constructor.

PiperOrigin-RevId: 241662817
This commit is contained in:
Lifeng Nai 2019-04-02 21:40:20 -07:00 committed by TensorFlower Gardener
parent 483e187a7b
commit 1597edbeed
5 changed files with 37 additions and 10 deletions

View File

@ -123,7 +123,22 @@ AnalyticalCostEstimator::AnalyticalCostEstimator(
use_aggressive_shape_inference_(use_aggressive_shape_inference) { use_aggressive_shape_inference_(use_aggressive_shape_inference) {
scheduler_ = absl::make_unique<VirtualScheduler>( scheduler_ = absl::make_unique<VirtualScheduler>(
use_static_shapes_, use_aggressive_shape_inference_, cluster, use_static_shapes_, use_aggressive_shape_inference_, cluster,
node_manager_.get()); node_manager_.get(),
absl::make_unique<VirtualPlacer>(cluster->GetDevices()));
}
AnalyticalCostEstimator::AnalyticalCostEstimator(
Cluster* cluster, std::unique_ptr<OpLevelCostEstimator> node_estimator,
std::unique_ptr<ReadyNodeManager> node_manager,
std::unique_ptr<VirtualPlacer> placer, bool use_static_shapes,
bool use_aggressive_shape_inference)
: node_estimator_(std::move(node_estimator)),
node_manager_(std::move(node_manager)),
use_static_shapes_(use_static_shapes),
use_aggressive_shape_inference_(use_aggressive_shape_inference) {
scheduler_ = absl::make_unique<VirtualScheduler>(
use_static_shapes_, use_aggressive_shape_inference_, cluster,
node_manager_.get(), std::move(placer));
} }
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) { Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {

View File

@ -47,6 +47,12 @@ class AnalyticalCostEstimator : public CostEstimator {
std::unique_ptr<ReadyNodeManager> node_manager, std::unique_ptr<ReadyNodeManager> node_manager,
bool use_static_shapes, bool use_static_shapes,
bool use_aggressive_shape_inference); bool use_aggressive_shape_inference);
AnalyticalCostEstimator(Cluster* cluster,
std::unique_ptr<OpLevelCostEstimator> node_estimator,
std::unique_ptr<ReadyNodeManager> node_manager,
std::unique_ptr<VirtualPlacer> placer,
bool use_static_shapes,
bool use_aggressive_shape_inference);
~AnalyticalCostEstimator() override {} ~AnalyticalCostEstimator() override {}
// Initializes the estimator for the specified grappler item. // Initializes the estimator for the specified grappler item.

View File

@ -259,13 +259,15 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
VirtualScheduler::VirtualScheduler(const bool use_static_shapes, VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference, const bool use_aggressive_shape_inference,
Cluster* cluster, Cluster* cluster,
ReadyNodeManager* ready_nodes) ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer)
: ready_nodes_(ready_nodes), : ready_nodes_(ready_nodes),
graph_costs_(Costs::ZeroCosts()), graph_costs_(Costs::ZeroCosts()),
cluster_(cluster), cluster_(cluster),
use_static_shapes_(use_static_shapes), use_static_shapes_(use_static_shapes),
use_aggressive_shape_inference_(use_aggressive_shape_inference), use_aggressive_shape_inference_(use_aggressive_shape_inference),
placer_(cluster->GetDevices()) { placer_(std::move(placer)) {
DCHECK(placer_); // check if the pointer is valid.
graph_costs_.num_ops_total = 0; graph_costs_.num_ops_total = 0;
initialized_ = false; initialized_ = false;
track_mem_usage_snapshot_ = VLOG_IS_ON(1); track_mem_usage_snapshot_ = VLOG_IS_ON(1);
@ -524,13 +526,13 @@ bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
} }
string VirtualScheduler::DeviceName(const NodeDef* node) const { string VirtualScheduler::DeviceName(const NodeDef* node) const {
return placer_.get_canonical_device_name(*node); return placer_->get_canonical_device_name(*node);
} }
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const { string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
// Replace the ":" characters that may be present in the device name with "_". // Replace the ":" characters that may be present in the device name with "_".
// This makes it possible to then use the resulting string in a node name. // This makes it possible to then use the resulting string in a node name.
return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":", return str_util::StringReplace(placer_->get_canonical_device_name(*node), ":",
"_", true); "_", true);
} }
@ -620,7 +622,7 @@ OpContext VirtualScheduler::GetCurrNode() const {
// Get the device from the placer. // Get the device from the placer.
DeviceProperties device; DeviceProperties device;
device = placer_.get_device(*node); device = placer_->get_device(*node);
// Special case for _Send op. // Special case for _Send op.
if (IsSend(*node)) { if (IsSend(*node)) {

View File

@ -263,7 +263,9 @@ class VirtualScheduler {
// Does not take ownership of cluster or ready_nodes. // Does not take ownership of cluster or ready_nodes.
VirtualScheduler(const bool use_static_shapes, VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference, Cluster* cluster, const bool use_aggressive_shape_inference, Cluster* cluster,
ReadyNodeManager* ready_nodes); ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer);
// Initializes the scheduler for the specific grappler item. // Initializes the scheduler for the specific grappler item.
// Should be called immediately after the c'tor or when the scheduler will be // Should be called immediately after the c'tor or when the scheduler will be
// reused for a new grappler item. All internal states of the scheduler // reused for a new grappler item. All internal states of the scheduler
@ -356,7 +358,7 @@ class VirtualScheduler {
bool track_mem_usage_snapshot_; bool track_mem_usage_snapshot_;
const bool use_aggressive_shape_inference_; const bool use_aggressive_shape_inference_;
VirtualPlacer placer_; // owned. std::unique_ptr<VirtualPlacer> placer_;
}; };
} // namespace grappler } // namespace grappler

View File

@ -33,8 +33,10 @@ class TestVirtualScheduler : public VirtualScheduler {
TestVirtualScheduler(const bool use_static_shapes, TestVirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference, const bool use_aggressive_shape_inference,
Cluster* cluster) Cluster* cluster)
: VirtualScheduler(use_static_shapes, use_aggressive_shape_inference, : VirtualScheduler(
cluster, &ready_node_manager_) { use_static_shapes, use_aggressive_shape_inference, cluster,
&ready_node_manager_,
absl::make_unique<VirtualPlacer>(cluster->GetDevices())) {
enable_mem_usage_tracking(); enable_mem_usage_tracking();
} }