update VirtualPlacer constructor interface.

PiperOrigin-RevId: 240923787
This commit is contained in:
Lifeng Nai 2019-03-28 22:54:35 -07:00 committed by TensorFlower Gardener
parent d11a7f8d20
commit f089b3180e
7 changed files with 23 additions and 25 deletions

View File

@ -23,14 +23,12 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
VirtualPlacer::VirtualPlacer(const Cluster* cluster) {
CHECK(cluster);
// Default job name for canonical device name. Needs to be set before the
// first call to to_lfqn_or_empty()
default_job_name_lowercase_ = "localhost";
devices_ = cluster->GetDevices();
VirtualPlacer::VirtualPlacer(
const std::unordered_map<string, DeviceProperties>& devices)
: devices_(devices),
// Default job name for canonical device name. Needs to be set before the
// first call to to_lfqn_or_empty()
default_job_name_lowercase_("localhost") {
lfqn_map_.reserve(devices_.size());
for (const auto& kv : devices_) {
const auto lfqn = to_lfqn_or_empty(kv.first);

View File

@ -28,7 +28,7 @@ class Cluster;
// The virtual placer emulates the behavior of the TF placer.
class VirtualPlacer {
public:
VirtualPlacer(const Cluster* cluster);
VirtualPlacer(const std::unordered_map<string, DeviceProperties>& devices);
const DeviceProperties& get_device(const NodeDef& node) const;

View File

@ -33,7 +33,7 @@ TEST(VirtualPlacerTest, LocalDevices) {
gpu_device.set_type("GPU");
devices["/job:localhost/replica:0/task:0/device:GPU:0"] = gpu_device;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
VirtualPlacer placer(devices);
NodeDef node;
node.set_op("Conv2D");
@ -63,7 +63,7 @@ TEST(VirtualPlacerTest, ShortNames) {
gpu_device.set_type("GPU");
devices["/GPU:0"] = gpu_device;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
VirtualPlacer placer(devices);
NodeDef node;
node.set_op("Conv2D");
@ -93,7 +93,7 @@ TEST(VirtualPlacerTest, PlacementOnNonDefaultDevice) {
tpu_device.set_type("TPU");
devices["/job:localhost/replica:0/task:0/device:TPU:0"] = tpu_device;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
VirtualPlacer placer(devices);
NodeDef node;
node.set_op("Conv2D");
@ -123,7 +123,7 @@ TEST(VirtualPlacerTest, EmptyJobName) {
devices[strings::StrCat("/job:", job_name,
"/replica:0/task:0/device:GPU:0")] = gpu_device;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
VirtualPlacer placer(devices);
NodeDef node;
node.set_op("Conv2D");
@ -145,7 +145,7 @@ TEST(VirtualPlacerTest, EmptyJobName) {
devices["/job:ps/replica:0/task:0/cpu:0"] = cpu_device;
devices["/job:worker/replica:0/task:0/cpu:0"] = cpu_device;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
VirtualPlacer placer(devices);
NodeDef node;
node.set_op("Conv2D");
@ -157,7 +157,7 @@ TEST(VirtualPlacerTest, EmptyJobName) {
string GetDefaultDeviceName(
const std::unordered_map<string, DeviceProperties>& devices) {
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
VirtualPlacer placer(devices);
NodeDef node;
node.set_op("Conv2D");
// Device is not set to the node, so get_canonical_device_name() will return
@ -204,7 +204,7 @@ TEST(VirtualPlacerTest, MultiReplica) {
}
std::unique_ptr<VirtualCluster> cluster(new VirtualCluster(devices));
std::unique_ptr<VirtualPlacer> placer(new VirtualPlacer(cluster.get()));
std::unique_ptr<VirtualPlacer> placer(new VirtualPlacer(devices));
auto get_device_name = [&placer](const string& device) -> string {
NodeDef node;
@ -235,7 +235,7 @@ TEST(VirtualPlacerTest, MultiReplica) {
cpu_device;
}
cluster.reset(new VirtualCluster(devices));
placer.reset(new VirtualPlacer(cluster.get()));
placer.reset(new VirtualPlacer(cluster->GetDevices()));
EXPECT_EQ("/job:worker/replica:0/task:0/cpu:0",
get_device_name("/job:worker/replica:0/cpu:0"));
EXPECT_EQ("/job:worker/replica:7/task:0/gpu:3",
@ -255,7 +255,7 @@ TEST(VirtualPlacerTest, FallBackUnknown) {
// cluster.
std::unordered_map<string, DeviceProperties> devices;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
VirtualPlacer placer(devices);
NodeDef node;
node.set_op("Conv2D");
@ -271,7 +271,7 @@ TEST(VirtualPlacerTest, FallBackCPU) {
cpu_device.set_type("CPU");
devices["/job:my_job/replica:0/task:0/cpu:0"] = cpu_device;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
VirtualPlacer placer(devices);
NodeDef node;
node.set_op("Conv2D");
@ -291,7 +291,7 @@ TEST(VirtualPlacerTest, RemoteDevices) {
gpu_device.set_type("GPU");
devices["/job:my_job/replica:0/task:0/device:GPU:0"] = gpu_device;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
VirtualPlacer placer(devices);
NodeDef node;
node.set_op("Conv2D");

View File

@ -265,7 +265,7 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
cluster_(cluster),
use_static_shapes_(use_static_shapes),
use_aggressive_shape_inference_(use_aggressive_shape_inference),
placer_(cluster) {
placer_(cluster->GetDevices()) {
graph_costs_.num_ops_total = 0;
initialized_ = false;
track_mem_usage_snapshot_ = VLOG_IS_ON(1);

View File

@ -2207,7 +2207,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
return Status::OK();
}
virtual_placer_.reset(new VirtualPlacer(cluster));
virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices()));
nodes_to_preserve_ = item.NodesToPreserve();
GraphProperties graph_properties(item);
auto status = graph_properties.InferStatically(false);

View File

@ -1223,7 +1223,7 @@ TEST_F(LayoutOptimizerTest, DevicePlacement) {
auto i = ops::Identity(s.WithOpName("i"), shape);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
VirtualPlacer virtual_placer(virtual_cluster_.get());
VirtualPlacer virtual_placer(virtual_cluster_->GetDevices());
for (auto& node : *item.graph.mutable_node()) {
string device = virtual_placer.get_canonical_device_name(node);
node.set_device(device);

View File

@ -94,7 +94,7 @@ Status EstimateEarliestExecutionTimes(
GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically(true));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster);
VirtualPlacer placer(cluster->GetDevices());
while (!ready_nodes.empty()) {
const NodeDef* node = ready_nodes.front();
@ -162,7 +162,7 @@ Status EstimateRequiredTimes(
GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically(true));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster);
VirtualPlacer placer(cluster->GetDevices());
while (!ready_nodes.empty()) {
const NodeDef* node = ready_nodes.front();