Update VirtualScheduler constructor.

PiperOrigin-RevId: 241662817
This commit is contained in:
Lifeng Nai 2019-04-02 21:40:20 -07:00 committed by TensorFlower Gardener
parent 483e187a7b
commit 1597edbeed
5 changed files with 37 additions and 10 deletions

View File

@ -123,7 +123,22 @@ AnalyticalCostEstimator::AnalyticalCostEstimator(
use_aggressive_shape_inference_(use_aggressive_shape_inference) {
scheduler_ = absl::make_unique<VirtualScheduler>(
use_static_shapes_, use_aggressive_shape_inference_, cluster,
node_manager_.get());
node_manager_.get(),
absl::make_unique<VirtualPlacer>(cluster->GetDevices()));
}
AnalyticalCostEstimator::AnalyticalCostEstimator(
Cluster* cluster, std::unique_ptr<OpLevelCostEstimator> node_estimator,
std::unique_ptr<ReadyNodeManager> node_manager,
std::unique_ptr<VirtualPlacer> placer, bool use_static_shapes,
bool use_aggressive_shape_inference)
: node_estimator_(std::move(node_estimator)),
node_manager_(std::move(node_manager)),
use_static_shapes_(use_static_shapes),
use_aggressive_shape_inference_(use_aggressive_shape_inference) {
scheduler_ = absl::make_unique<VirtualScheduler>(
use_static_shapes_, use_aggressive_shape_inference_, cluster,
node_manager_.get(), std::move(placer));
}
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {

View File

@ -47,6 +47,12 @@ class AnalyticalCostEstimator : public CostEstimator {
std::unique_ptr<ReadyNodeManager> node_manager,
bool use_static_shapes,
bool use_aggressive_shape_inference);
AnalyticalCostEstimator(Cluster* cluster,
std::unique_ptr<OpLevelCostEstimator> node_estimator,
std::unique_ptr<ReadyNodeManager> node_manager,
std::unique_ptr<VirtualPlacer> placer,
bool use_static_shapes,
bool use_aggressive_shape_inference);
~AnalyticalCostEstimator() override {}
// Initializes the estimator for the specified grappler item.

View File

@ -259,13 +259,15 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster,
ReadyNodeManager* ready_nodes)
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer)
: ready_nodes_(ready_nodes),
graph_costs_(Costs::ZeroCosts()),
cluster_(cluster),
use_static_shapes_(use_static_shapes),
use_aggressive_shape_inference_(use_aggressive_shape_inference),
placer_(cluster->GetDevices()) {
placer_(std::move(placer)) {
DCHECK(placer_); // check if the pointer is valid.
graph_costs_.num_ops_total = 0;
initialized_ = false;
track_mem_usage_snapshot_ = VLOG_IS_ON(1);
@ -524,13 +526,13 @@ bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
}
string VirtualScheduler::DeviceName(const NodeDef* node) const {
return placer_.get_canonical_device_name(*node);
return placer_->get_canonical_device_name(*node);
}
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
// Replace the ":" characters that may be present in the device name with "_".
// This makes it possible to then use the resulting string in a node name.
return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":",
return str_util::StringReplace(placer_->get_canonical_device_name(*node), ":",
"_", true);
}
@ -620,7 +622,7 @@ OpContext VirtualScheduler::GetCurrNode() const {
// Get the device from the placer.
DeviceProperties device;
device = placer_.get_device(*node);
device = placer_->get_device(*node);
// Special case for _Send op.
if (IsSend(*node)) {

View File

@ -263,7 +263,9 @@ class VirtualScheduler {
// Does not take ownership of cluster or ready_nodes.
VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference, Cluster* cluster,
ReadyNodeManager* ready_nodes);
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer);
// Initializes the scheduler for the specific grappler item.
// Should be called immediately after the c'tor or when the scheduler will be
// reused for a new grappler item. All internal states of the scheduler
@ -356,7 +358,7 @@ class VirtualScheduler {
bool track_mem_usage_snapshot_;
const bool use_aggressive_shape_inference_;
VirtualPlacer placer_; // owned.
std::unique_ptr<VirtualPlacer> placer_;
};
} // namespace grappler

View File

@ -33,8 +33,10 @@ class TestVirtualScheduler : public VirtualScheduler {
TestVirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster)
: VirtualScheduler(use_static_shapes, use_aggressive_shape_inference,
cluster, &ready_node_manager_) {
: VirtualScheduler(
use_static_shapes, use_aggressive_shape_inference, cluster,
&ready_node_manager_,
absl::make_unique<VirtualPlacer>(cluster->GetDevices())) {
enable_mem_usage_tracking();
}