Pass the device ordinal to use for execution to the XLA compiler for
auto-tuning. Previously, when compiling a graph for multiple devices concurrently, XLA would use the default device for auto-tuning. With this patch tf_cnn_benchmark with model resnet50 finishes on 8 V100s batch 128, and gets a speedup of ~20% over a single one; the next steps are to get it to run at batch 256 and to scale well. PiperOrigin-RevId: 206720140
This commit is contained in:
parent
78f58629ae
commit
3a1df26a25
@ -153,6 +153,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
|
||||
XlaCompiler::Options options;
|
||||
options.client = client;
|
||||
if (ctx->op_device_context() != nullptr) {
|
||||
options.device_ordinal =
|
||||
ctx->op_device_context()->stream()->parent()->device_ordinal();
|
||||
}
|
||||
options.device_type = cache->device_type();
|
||||
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
|
||||
options.graph_def_version = ctx->function_library()->graph_def_version();
|
||||
|
@ -209,7 +209,9 @@ Status XlaCompilationCache::BuildExecutable(
|
||||
argument_layouts[i] = &result.xla_input_shapes[i];
|
||||
}
|
||||
xla::ExecutableBuildOptions build_options;
|
||||
build_options.set_device_ordinal(client_->default_device_ordinal());
|
||||
build_options.set_device_ordinal(options.device_ordinal != -1
|
||||
? options.device_ordinal
|
||||
: client_->default_device_ordinal());
|
||||
build_options.set_result_layout(result.xla_output_shape);
|
||||
build_options.set_device_allocator(options.device_allocator);
|
||||
|
||||
|
@ -252,6 +252,12 @@ class XlaCompiler {
|
||||
// The default empty value is invalid.
|
||||
DeviceType device_type = DeviceType("");
|
||||
|
||||
// The device to use during compilation to execute instructions on, for
|
||||
// example for auto-tuning.
|
||||
// Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
|
||||
// -1 indicates the default device should be used.
|
||||
int device_ordinal = -1;
|
||||
|
||||
xla::Client* client = nullptr;
|
||||
|
||||
// Function library in which to find function definitions. Must be non-null.
|
||||
|
Loading…
Reference in New Issue
Block a user