Set default device info if device type is neither CPU nor GPU.
OpLevelCostEstimator::GetDeviceInfo() returns the GFLOPs/sec and GB/sec of the assumed machine, but it currently assumes the device type is either CPU or GPU. We did not expect other device types to be passed during Grappler optimization, because transfer operations such as _Send and _Recv are added only when a graph actually runs. On some occasions, however, we encountered input graphs that include a _HostSend op (a graph may also include _Send or _Recv ops; that happened in some test cases). A device type other than CPU or GPU currently causes a crash; to avoid this, we treat an unknown device type as a transfer operation over PCIe and set the default PCIe x16 gen3 bandwidth. As the GetDeviceInfo() method is virtual, one may override it if more precise device info is needed. PiperOrigin-RevId: 356865551 Change-Id: Icf1f39e5da81f9c627aa01e479f437184a13eed3
This commit is contained in:
parent
ba45035f3a
commit
05756fcc81
@ -714,6 +714,8 @@ Status OpLevelCostEstimator::PredictNodeCosts(const OpContext& op_context,
|
||||
return PredictCostOfAnUnknownOp(op_context, node_costs);
|
||||
}
|
||||
|
||||
// This method assumes a typical system composed of CPUs and GPUs, connected
|
||||
// through PCIe. To define device info more precisely, override this method.
|
||||
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
|
||||
const DeviceProperties& device) const {
|
||||
double gflops = -1;
|
||||
@ -755,13 +757,15 @@ DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
|
||||
} else {
|
||||
gb_per_sec = 100;
|
||||
}
|
||||
} else {
|
||||
LOG_EVERY_N(WARNING, 1000) << "Unknown device type: " << device.type()
|
||||
<< ", assuming PCIe between CPU and GPU.";
|
||||
gflops = 1; // Dummy value; data transfer ops would not have compute ops.
|
||||
gb_per_sec = 12; // default PCIe x16 gen3.
|
||||
}
|
||||
VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
|
||||
<< " gb_per_sec: " << gb_per_sec;
|
||||
|
||||
DCHECK_LT(0, gflops) << device.DebugString();
|
||||
DCHECK_LT(0, gb_per_sec) << device.DebugString();
|
||||
|
||||
return DeviceInfo(gflops, gb_per_sec);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user