Add an option to override XLA GPU conv layouts to NHWC
This is just for testing. PiperOrigin-RevId: 342741577 Change-Id: I8cc9da1efa562f9bd9966f8e3390cf5f9d352d3a
This commit is contained in:
parent
d652afc0d8
commit
d259d86efc
@ -540,7 +540,12 @@ static void AllocateFlags() {
|
||||
"xla_gpu_force_conv_nchw",
|
||||
bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw),
|
||||
flag_values->xla_gpu_force_conv_nchw(),
|
||||
"For cuDNN convolutions, always NCHW layouts."));
|
||||
"For cuDNN convolutions, always use NCHW layouts."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_gpu_force_conv_nhwc",
|
||||
bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nhwc),
|
||||
flag_values->xla_gpu_force_conv_nhwc(),
|
||||
"For cuDNN convolutions, always use NHWC layouts."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_gpu_algorithm_denylist_path",
|
||||
string_setter_for(&DebugOptions::set_xla_gpu_algorithm_denylist_path),
|
||||
|
@ -68,9 +68,15 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
|
||||
instr->GetModule()->config().debug_options();
|
||||
|
||||
if (debug_options.xla_gpu_force_conv_nchw()) {
|
||||
VLOG(2) << "Overriding layout to NCHW for " << instr->ToString();
|
||||
return kAllNCHW;
|
||||
}
|
||||
|
||||
if (debug_options.xla_gpu_force_conv_nhwc()) {
|
||||
VLOG(2) << "Overriding layout to NHWC for " << instr->ToString();
|
||||
return kAllNHWC;
|
||||
}
|
||||
|
||||
// If we're not Volta or not fp16, or not conv2D, the decision is easy: Use
|
||||
// NCHW.
|
||||
if (instr->operand(0)->shape().element_type() != xla::PrimitiveType::F16 ||
|
||||
|
@ -268,7 +268,9 @@ message DebugOptions {
|
||||
// END flags controlling dumping HLO modules.
|
||||
//
|
||||
|
||||
// Overrides for XLA GPU's convolution layout heuristic.
|
||||
bool xla_gpu_force_conv_nchw = 125;
|
||||
bool xla_gpu_force_conv_nhwc = 146;
|
||||
|
||||
// Paths to files with ptx code.
|
||||
repeated string xla_gpu_ptx_file = 127;
|
||||
@ -300,7 +302,7 @@ message DebugOptions {
|
||||
// Enable detailed logging into vlog.
|
||||
bool xla_detailed_logging = 143;
|
||||
|
||||
// Next id: 145
|
||||
// Next id: 146
|
||||
|
||||
// Extra options to pass to the compilation backend (e.g. LLVM); specific
|
||||
// interpretation of these values is left to the backend.
|
||||
|
Loading…
Reference in New Issue
Block a user