Add an option to override XLA GPU conv layouts to NHWC

This is just for testing.

PiperOrigin-RevId: 342741577
Change-Id: I8cc9da1efa562f9bd9966f8e3390cf5f9d352d3a
This commit is contained in:
Sanjoy Das 2020-11-16 16:02:53 -08:00 committed by TensorFlower Gardener
parent d652afc0d8
commit d259d86efc
3 changed files with 15 additions and 2 deletions

View File

@ -540,7 +540,12 @@ static void AllocateFlags() {
"xla_gpu_force_conv_nchw",
bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw),
flag_values->xla_gpu_force_conv_nchw(),
"For cuDNN convolutions, always NCHW layouts."));
"For cuDNN convolutions, always use NCHW layouts."));
flag_objects->push_back(tensorflow::Flag(
"xla_gpu_force_conv_nhwc",
bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nhwc),
flag_values->xla_gpu_force_conv_nhwc(),
"For cuDNN convolutions, always use NHWC layouts."));
flag_objects->push_back(tensorflow::Flag(
"xla_gpu_algorithm_denylist_path",
string_setter_for(&DebugOptions::set_xla_gpu_algorithm_denylist_path),

View File

@ -68,9 +68,15 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
instr->GetModule()->config().debug_options();
if (debug_options.xla_gpu_force_conv_nchw()) {
VLOG(2) << "Overriding layout to NCHW for " << instr->ToString();
return kAllNCHW;
}
if (debug_options.xla_gpu_force_conv_nhwc()) {
VLOG(2) << "Overriding layout to NHWC for " << instr->ToString();
return kAllNHWC;
}
// If we're not Volta or not fp16, or not conv2D, the decision is easy: Use
// NCHW.
if (instr->operand(0)->shape().element_type() != xla::PrimitiveType::F16 ||

View File

@ -268,7 +268,9 @@ message DebugOptions {
// END flags controlling dumping HLO modules.
//
// Overrides for XLA GPU's convolution layout heuristic.
bool xla_gpu_force_conv_nchw = 125;
bool xla_gpu_force_conv_nhwc = 146;
// Paths to files with ptx code.
repeated string xla_gpu_ptx_file = 127;
@ -300,7 +302,7 @@ message DebugOptions {
// Enable detailed logging into vlog.
bool xla_detailed_logging = 143;
// Next id: 145
// Next id: 147
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.