Set default device info if device type is neither CPU nor GPU.

OpLevelCostEstimator::GetDeviceInfo() returns the GFLOPs/sec and GB/sec of
the assumed machine, but it currently assumes the device type is either CPU
or GPU. We did not expect other device types to be passed during Grappler
optimization, because transfer operations such as _Send and _Recv are added
when a graph actually runs. On some occasions, however, the input graph
included a _HostSend op (a graph may also include _Send or _Recv ops; that
happened in some test cases).

A device type other than CPU or GPU currently causes a crash. To avoid it, we
treat an unknown device type as a transfer operation over PCIe and set the
default PCIe x16 gen3 bandwidth (12 GB/sec), with a dummy value of
1 GFLOPs/sec for compute.
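
For scale, the fallback numbers can be sketched outside TensorFlow. The block below is only an illustration: FallbackInfo, UnknownDeviceFallback, PcieGen3x16PeakGBPerSec and TransferNanos are hypothetical names, and the bytes / bandwidth cost model is an assumption, not something this change states. PCIe 3.0 runs at 8 GT/s per lane with 128b/130b encoding, so an x16 link peaks near 15.75 GB/sec in theory; the 12 GB/sec default sits below that peak.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for grappler's DeviceInfo; not the TensorFlow type.
struct FallbackInfo {
  double gflops;      // GFLOPs/sec
  double gb_per_sec;  // GB/sec
};

// Values taken from this change for an unknown device type.
inline FallbackInfo UnknownDeviceFallback() {
  FallbackInfo info{/*gflops=*/1.0, /*gb_per_sec=*/12.0};
  // Mirrors the DCHECK_LT(0, ...) checks in GetDeviceInfo().
  assert(info.gflops > 0);
  assert(info.gb_per_sec > 0);
  return info;
}

// Theoretical PCIe 3.0 x16 peak in GB/sec: 8 GT/s per lane, 16 lanes,
// 128b/130b encoding, 8 bits per byte.
inline double PcieGen3x16PeakGBPerSec() {
  return 8.0 * 16.0 * (128.0 / 130.0) / 8.0;
}

// Assumed cost model (not taken from the source): bytes divided by GB/sec
// yields nanoseconds, since 1 GB/sec equals 1 byte/ns.
inline double TransferNanos(std::int64_t bytes, const FallbackInfo& info) {
  return static_cast<double>(bytes) / info.gb_per_sec;
}
```

Under that assumed model, a 1200-byte transfer on the fallback device would be costed at 100 ns.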

As the GetDeviceInfo() method is virtual, one may override this method
if more precise device info is needed.
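
A rough sketch of such an override, using stand-in types so it compiles without TensorFlow. Only the shape of GetDeviceInfo() follows the diff; SketchDeviceProperties, SketchDeviceInfo, SketchEstimator, CustomEstimator, the "MY_ACCELERATOR" device type and its figures are all hypothetical.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for the grappler types; not the TensorFlow classes.
struct SketchDeviceProperties {
  std::string type;
};

struct SketchDeviceInfo {
  double gflops;      // GFLOPs/sec
  double gb_per_sec;  // GB/sec
};

class SketchEstimator {
 public:
  virtual ~SketchEstimator() = default;

  // Stand-in for the default behavior; only the unknown-device fallback
  // from this change (1 GFLOPs/sec, 12 GB/sec) is modeled here.
  virtual SketchDeviceInfo GetDeviceInfo(
      const SketchDeviceProperties& /*device*/) const {
    return SketchDeviceInfo{1.0, 12.0};
  }
};

class CustomEstimator : public SketchEstimator {
 public:
  SketchDeviceInfo GetDeviceInfo(
      const SketchDeviceProperties& device) const override {
    if (device.type == "MY_ACCELERATOR") {  // hypothetical device type
      SketchDeviceInfo info{1000.0, 200.0};  // made-up figures
      // Same positivity requirement as the DCHECK_LT checks in the diff.
      assert(info.gflops > 0);
      assert(info.gb_per_sec > 0);
      return info;
    }
    // Anything else falls through to the default.
    return SketchEstimator::GetDeviceInfo(device);
  }
};
```

Delegating to the base method for unrecognized types keeps the PCIe fallback in place, so the override only has to describe the devices it knows about.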

PiperOrigin-RevId: 356865551
Change-Id: Icf1f39e5da81f9c627aa01e479f437184a13eed3
Doe Hyun Yoon 2021-02-10 17:26:46 -08:00 committed by TensorFlower Gardener
parent ba45035f3a
commit 05756fcc81

@@ -714,6 +714,8 @@ Status OpLevelCostEstimator::PredictNodeCosts(const OpContext& op_context,
  return PredictCostOfAnUnknownOp(op_context, node_costs);
}
// This method assumes a typical system composed of CPUs and GPUs, connected
// through PCIe. To define device info more precisely, override this method.
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
    const DeviceProperties& device) const {
  double gflops = -1;
@@ -755,13 +757,15 @@ DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
    } else {
      gb_per_sec = 100;
    }
  } else {
    LOG_EVERY_N(WARNING, 1000) << "Unknown device type: " << device.type()
                               << ", assuming PCIe between CPU and GPU.";
    gflops = 1;  // Dummy value; data transfer ops would not have compute ops.
    gb_per_sec = 12;  // default PCIe x16 gen3.
  }
  VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
          << " gb_per_sec: " << gb_per_sec;
  DCHECK_LT(0, gflops) << device.DebugString();
  DCHECK_LT(0, gb_per_sec) << device.DebugString();
  return DeviceInfo(gflops, gb_per_sec);
}