[XLA:GPU] Clarify HeuristicLayoutAssignment function.
No functional change.

PiperOrigin-RevId: 206090291
commit 5d92abe1e4 (parent b24037513f)
@@ -52,31 +52,38 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
   // W <=> X
   //
   // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
   //
   // If you have trouble keeping these straight, consider that all that matters
   // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?

   // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x
   // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version
   // changes, as well as the hardware updates.
+  constexpr auto kAllNCHW =
+      std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
+                      DataLayout::kBatchDepthYX);
+  constexpr auto kAllNHWC =
+      std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
+                      DataLayout::kBatchYXDepth);
+
+  // If we're not Volta or not fp16, the decision is easy: Use NCHW.
   if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 &&
         IsVoltaOrLater(*stream_executor))) {
-    return std::make_tuple(DataLayout::kBatchDepthYX,
-                           FilterLayout::kOutputInputYX,
-                           DataLayout::kBatchDepthYX);
+    return kAllNCHW;
   }

   VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
-  // For BackwardInput that has stride, full NHWC layouts run significantly
-  // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC).
-  //
-  // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW,
-  // NHWC).
+
+  // Empirically we've found with Volta and cudnn 7 that backward-input convs
+  // with stride are significantly faster with input in NHWC and filter/output
+  // in NCHW.
   if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
       window_util::HasStride(instr->window())) {
-    return std::make_tuple(DataLayout::kBatchYXDepth,
-                           FilterLayout::kOutputInputYX,
-                           DataLayout::kBatchDepthYX);
+    return std::make_tuple(DataLayout::kBatchYXDepth,    // NHWC
+                           FilterLayout::kOutputInputYX,  // NCHW
+                           DataLayout::kBatchDepthYX      // NCHW
+                           );
   }
-  return std::make_tuple(DataLayout::kBatchYXDepth,
-                         FilterLayout::kOutputYXInput,
-                         DataLayout::kBatchYXDepth);
+
+  // For other Volta f16 convolutions, use NHWC.
+  return kAllNHWC;
 }

 // Adds layout constraints on the cudnn custom-call instruction. The layout
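For readers who want to see the heuristic in isolation, the sketch below mirrors the decision structure from the diff as a standalone C++ program. It is illustrative only: the enums, the LayoutTriple alias, the ChooseLayouts function, and its two bool parameters are stand-ins invented for this example, not the XLA:GPU types or API.

// Standalone sketch of the layout heuristic's decision structure.
// The enums and ChooseLayouts() are illustrative stand-ins, not XLA types.
#include <iostream>
#include <tuple>

enum class DataLayout { kBatchDepthYX /*NCHW*/, kBatchYXDepth /*NHWC*/ };
enum class FilterLayout { kOutputInputYX /*NCHW*/, kOutputYXInput /*NHWC*/ };

// (input layout, filter layout, output layout)
using LayoutTriple = std::tuple<DataLayout, FilterLayout, DataLayout>;

// is_f16_on_volta: operand type is F16 and the GPU is Volta or later.
// is_strided_backward_input: the conv is a backward-input conv with stride.
LayoutTriple ChooseLayouts(bool is_f16_on_volta,
                           bool is_strided_backward_input) {
  constexpr auto kAllNCHW =
      std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
                      DataLayout::kBatchDepthYX);
  constexpr auto kAllNHWC =
      std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
                      DataLayout::kBatchYXDepth);

  // Not (fp16 on Volta or later): the decision is easy, use NCHW everywhere.
  if (!is_f16_on_volta) return kAllNCHW;

  // Strided backward-input conv: NHWC input, NCHW filter and output.
  if (is_strided_backward_input) {
    return std::make_tuple(DataLayout::kBatchYXDepth,    // NHWC
                           FilterLayout::kOutputInputYX,  // NCHW
                           DataLayout::kBatchDepthYX);    // NCHW
  }

  // All other Volta f16 convolutions: NHWC everywhere.
  return kAllNHWC;
}

int main() {
  auto layouts = ChooseLayouts(/*is_f16_on_volta=*/true,
                               /*is_strided_backward_input=*/false);
  std::cout << "input layout is NHWC: "
            << (std::get<0>(layouts) == DataLayout::kBatchYXDepth) << "\n";
  return 0;
}

The constants play the same role as kAllNCHW and kAllNHWC in the commit: naming the two all-NCHW and all-NHWC triples keeps the three return paths readable, which is the clarification the change is after.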