[XLA:GPU] Clarify HeuristicLayoutAssignment function.

No functional change.

PiperOrigin-RevId: 206090291
Justin Lebar 2018-07-25 18:32:48 -07:00 committed by TensorFlower Gardener
parent b24037513f
commit 5d92abe1e4

@@ -52,31 +52,38 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
   // W <=> X
   //
   // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
+  //
+  // If you have trouble keeping these straight, consider that all that matters
+  // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?
-  // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x
-  // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version
-  // changes, as well as the hardware updates.
+  constexpr auto kAllNCHW =
+      std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
+                      DataLayout::kBatchDepthYX);
+  constexpr auto kAllNHWC =
+      std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
+                      DataLayout::kBatchYXDepth);
+  // If we're not Volta or not fp16, the decision is easy: Use NCHW.
   if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 &&
         IsVoltaOrLater(*stream_executor))) {
-    return std::make_tuple(DataLayout::kBatchDepthYX,
-                           FilterLayout::kOutputInputYX,
-                           DataLayout::kBatchDepthYX);
+    return kAllNCHW;
   }
   VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
-  // For BackwardInput that has stride, full NHWC layouts run significantly
-  // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC).
-  //
-  // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW,
-  // NHWC).
+  // Empirically we've found with Volta and cudnn 7 that backward-input convs
+  // with stride are significantly faster with input in NHWC and filter/output
+  // in NCHW.
   if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
       window_util::HasStride(instr->window())) {
-    return std::make_tuple(DataLayout::kBatchYXDepth,
-                           FilterLayout::kOutputInputYX,
-                           DataLayout::kBatchDepthYX);
+    return std::make_tuple(DataLayout::kBatchYXDepth,     // NHWC
+                           FilterLayout::kOutputInputYX,  // NCHW
+                           DataLayout::kBatchDepthYX      // NCHW
+    );
   }
-  return std::make_tuple(DataLayout::kBatchYXDepth,
-                         FilterLayout::kOutputYXInput,
-                         DataLayout::kBatchYXDepth);
+  // For other Volta f16 convolutions, use NHWC.
+  return kAllNHWC;
 }
 // Adds layout constraints on the cudnn custom-call instruction. The layout
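For readers skimming the diff, here is a minimal standalone sketch of the decision tree the reworked function encodes. The names (Layout, ConvProperties, ChooseConvLayouts) are hypothetical stand-ins, not part of XLA; the real function operates on an HloInstruction and returns DataLayout/FilterLayout enum values.

#include <tuple>

// Hypothetical stand-ins for illustration only.
enum class Layout { NCHW, NHWC };

struct ConvProperties {
  bool is_fp16;            // operand element type is F16
  bool is_volta_or_later;  // running on a V100-class GPU or newer
  bool is_backward_input;  // cudnn backward-input convolution
  bool has_stride;         // convolution window has a stride
};

// Returns (input, filter, output) layouts, mirroring the heuristic above.
std::tuple<Layout, Layout, Layout> ChooseConvLayouts(const ConvProperties& c) {
  // Not Volta or not fp16: the decision is easy, use NCHW everywhere.
  if (!(c.is_fp16 && c.is_volta_or_later)) {
    return std::make_tuple(Layout::NCHW, Layout::NCHW, Layout::NCHW);
  }
  // Strided backward-input convs: NHWC input, NCHW filter and output.
  if (c.is_backward_input && c.has_stride) {
    return std::make_tuple(Layout::NHWC, Layout::NCHW, Layout::NCHW);
  }
  // All other Volta fp16 convs: NHWC everywhere.
  return std::make_tuple(Layout::NHWC, Layout::NHWC, Layout::NHWC);
}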