update VirtualPlacer constructor interface.
PiperOrigin-RevId: 240923787
This commit is contained in:
parent
d11a7f8d20
commit
f089b3180e
@ -23,14 +23,12 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
VirtualPlacer::VirtualPlacer(const Cluster* cluster) {
|
||||
CHECK(cluster);
|
||||
|
||||
// Default job name for canonical device name. Needs to be set before the
|
||||
// first call to to_lfqn_or_empty()
|
||||
default_job_name_lowercase_ = "localhost";
|
||||
|
||||
devices_ = cluster->GetDevices();
|
||||
VirtualPlacer::VirtualPlacer(
|
||||
const std::unordered_map<string, DeviceProperties>& devices)
|
||||
: devices_(devices),
|
||||
// Default job name for canonical device name. Needs to be set before the
|
||||
// first call to to_lfqn_or_empty()
|
||||
default_job_name_lowercase_("localhost") {
|
||||
lfqn_map_.reserve(devices_.size());
|
||||
for (const auto& kv : devices_) {
|
||||
const auto lfqn = to_lfqn_or_empty(kv.first);
|
||||
|
@ -28,7 +28,7 @@ class Cluster;
|
||||
// The virtual placer emulates the behavior of the TF placer.
|
||||
class VirtualPlacer {
|
||||
public:
|
||||
VirtualPlacer(const Cluster* cluster);
|
||||
VirtualPlacer(const std::unordered_map<string, DeviceProperties>& devices);
|
||||
|
||||
const DeviceProperties& get_device(const NodeDef& node) const;
|
||||
|
||||
|
@ -33,7 +33,7 @@ TEST(VirtualPlacerTest, LocalDevices) {
|
||||
gpu_device.set_type("GPU");
|
||||
devices["/job:localhost/replica:0/task:0/device:GPU:0"] = gpu_device;
|
||||
VirtualCluster cluster(devices);
|
||||
VirtualPlacer placer(&cluster);
|
||||
VirtualPlacer placer(devices);
|
||||
|
||||
NodeDef node;
|
||||
node.set_op("Conv2D");
|
||||
@ -63,7 +63,7 @@ TEST(VirtualPlacerTest, ShortNames) {
|
||||
gpu_device.set_type("GPU");
|
||||
devices["/GPU:0"] = gpu_device;
|
||||
VirtualCluster cluster(devices);
|
||||
VirtualPlacer placer(&cluster);
|
||||
VirtualPlacer placer(devices);
|
||||
|
||||
NodeDef node;
|
||||
node.set_op("Conv2D");
|
||||
@ -93,7 +93,7 @@ TEST(VirtualPlacerTest, PlacementOnNonDefaultDevice) {
|
||||
tpu_device.set_type("TPU");
|
||||
devices["/job:localhost/replica:0/task:0/device:TPU:0"] = tpu_device;
|
||||
VirtualCluster cluster(devices);
|
||||
VirtualPlacer placer(&cluster);
|
||||
VirtualPlacer placer(devices);
|
||||
|
||||
NodeDef node;
|
||||
node.set_op("Conv2D");
|
||||
@ -123,7 +123,7 @@ TEST(VirtualPlacerTest, EmptyJobName) {
|
||||
devices[strings::StrCat("/job:", job_name,
|
||||
"/replica:0/task:0/device:GPU:0")] = gpu_device;
|
||||
VirtualCluster cluster(devices);
|
||||
VirtualPlacer placer(&cluster);
|
||||
VirtualPlacer placer(devices);
|
||||
|
||||
NodeDef node;
|
||||
node.set_op("Conv2D");
|
||||
@ -145,7 +145,7 @@ TEST(VirtualPlacerTest, EmptyJobName) {
|
||||
devices["/job:ps/replica:0/task:0/cpu:0"] = cpu_device;
|
||||
devices["/job:worker/replica:0/task:0/cpu:0"] = cpu_device;
|
||||
VirtualCluster cluster(devices);
|
||||
VirtualPlacer placer(&cluster);
|
||||
VirtualPlacer placer(devices);
|
||||
|
||||
NodeDef node;
|
||||
node.set_op("Conv2D");
|
||||
@ -157,7 +157,7 @@ TEST(VirtualPlacerTest, EmptyJobName) {
|
||||
string GetDefaultDeviceName(
|
||||
const std::unordered_map<string, DeviceProperties>& devices) {
|
||||
VirtualCluster cluster(devices);
|
||||
VirtualPlacer placer(&cluster);
|
||||
VirtualPlacer placer(devices);
|
||||
NodeDef node;
|
||||
node.set_op("Conv2D");
|
||||
// Device is not set to the node, so get_canonical_device_name() will return
|
||||
@ -204,7 +204,7 @@ TEST(VirtualPlacerTest, MultiReplica) {
|
||||
}
|
||||
|
||||
std::unique_ptr<VirtualCluster> cluster(new VirtualCluster(devices));
|
||||
std::unique_ptr<VirtualPlacer> placer(new VirtualPlacer(cluster.get()));
|
||||
std::unique_ptr<VirtualPlacer> placer(new VirtualPlacer(devices));
|
||||
|
||||
auto get_device_name = [&placer](const string& device) -> string {
|
||||
NodeDef node;
|
||||
@ -235,7 +235,7 @@ TEST(VirtualPlacerTest, MultiReplica) {
|
||||
cpu_device;
|
||||
}
|
||||
cluster.reset(new VirtualCluster(devices));
|
||||
placer.reset(new VirtualPlacer(cluster.get()));
|
||||
placer.reset(new VirtualPlacer(cluster->GetDevices()));
|
||||
EXPECT_EQ("/job:worker/replica:0/task:0/cpu:0",
|
||||
get_device_name("/job:worker/replica:0/cpu:0"));
|
||||
EXPECT_EQ("/job:worker/replica:7/task:0/gpu:3",
|
||||
@ -255,7 +255,7 @@ TEST(VirtualPlacerTest, FallBackUnknown) {
|
||||
// cluster.
|
||||
std::unordered_map<string, DeviceProperties> devices;
|
||||
VirtualCluster cluster(devices);
|
||||
VirtualPlacer placer(&cluster);
|
||||
VirtualPlacer placer(devices);
|
||||
|
||||
NodeDef node;
|
||||
node.set_op("Conv2D");
|
||||
@ -271,7 +271,7 @@ TEST(VirtualPlacerTest, FallBackCPU) {
|
||||
cpu_device.set_type("CPU");
|
||||
devices["/job:my_job/replica:0/task:0/cpu:0"] = cpu_device;
|
||||
VirtualCluster cluster(devices);
|
||||
VirtualPlacer placer(&cluster);
|
||||
VirtualPlacer placer(devices);
|
||||
|
||||
NodeDef node;
|
||||
node.set_op("Conv2D");
|
||||
@ -291,7 +291,7 @@ TEST(VirtualPlacerTest, RemoteDevices) {
|
||||
gpu_device.set_type("GPU");
|
||||
devices["/job:my_job/replica:0/task:0/device:GPU:0"] = gpu_device;
|
||||
VirtualCluster cluster(devices);
|
||||
VirtualPlacer placer(&cluster);
|
||||
VirtualPlacer placer(devices);
|
||||
|
||||
NodeDef node;
|
||||
node.set_op("Conv2D");
|
||||
|
@ -265,7 +265,7 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||
cluster_(cluster),
|
||||
use_static_shapes_(use_static_shapes),
|
||||
use_aggressive_shape_inference_(use_aggressive_shape_inference),
|
||||
placer_(cluster) {
|
||||
placer_(cluster->GetDevices()) {
|
||||
graph_costs_.num_ops_total = 0;
|
||||
initialized_ = false;
|
||||
track_mem_usage_snapshot_ = VLOG_IS_ON(1);
|
||||
|
@ -2207,7 +2207,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
virtual_placer_.reset(new VirtualPlacer(cluster));
|
||||
virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices()));
|
||||
nodes_to_preserve_ = item.NodesToPreserve();
|
||||
GraphProperties graph_properties(item);
|
||||
auto status = graph_properties.InferStatically(false);
|
||||
|
@ -1223,7 +1223,7 @@ TEST_F(LayoutOptimizerTest, DevicePlacement) {
|
||||
auto i = ops::Identity(s.WithOpName("i"), shape);
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
VirtualPlacer virtual_placer(virtual_cluster_.get());
|
||||
VirtualPlacer virtual_placer(virtual_cluster_->GetDevices());
|
||||
for (auto& node : *item.graph.mutable_node()) {
|
||||
string device = virtual_placer.get_canonical_device_name(node);
|
||||
node.set_device(device);
|
||||
|
@ -94,7 +94,7 @@ Status EstimateEarliestExecutionTimes(
|
||||
GraphProperties properties(item);
|
||||
TF_RETURN_IF_ERROR(properties.InferStatically(true));
|
||||
OpLevelCostEstimator estimator;
|
||||
VirtualPlacer placer(cluster);
|
||||
VirtualPlacer placer(cluster->GetDevices());
|
||||
|
||||
while (!ready_nodes.empty()) {
|
||||
const NodeDef* node = ready_nodes.front();
|
||||
@ -162,7 +162,7 @@ Status EstimateRequiredTimes(
|
||||
GraphProperties properties(item);
|
||||
TF_RETURN_IF_ERROR(properties.InferStatically(true));
|
||||
OpLevelCostEstimator estimator;
|
||||
VirtualPlacer placer(cluster);
|
||||
VirtualPlacer placer(cluster->GetDevices());
|
||||
|
||||
while (!ready_nodes.empty()) {
|
||||
const NodeDef* node = ready_nodes.front();
|
||||
|
Loading…
Reference in New Issue
Block a user