Pass the device ordinal to use for execution to the XLA compiler for
auto-tuning. Previously, when compiling a graph for multiple devices concurrently, XLA would use the default device for auto-tuning. With this patch tf_cnn_benchmark with model resnet50 finishes on 8 V100s batch 128, and gets a speedup of ~20% over a single one; the next steps are to get it to run at batch 256 and to scale well. PiperOrigin-RevId: 206720140
This commit is contained in:
parent
78f58629ae
commit
3a1df26a25
@ -153,6 +153,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
|||||||
|
|
||||||
XlaCompiler::Options options;
|
XlaCompiler::Options options;
|
||||||
options.client = client;
|
options.client = client;
|
||||||
|
if (ctx->op_device_context() != nullptr) {
|
||||||
|
options.device_ordinal =
|
||||||
|
ctx->op_device_context()->stream()->parent()->device_ordinal();
|
||||||
|
}
|
||||||
options.device_type = cache->device_type();
|
options.device_type = cache->device_type();
|
||||||
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
|
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
|
||||||
options.graph_def_version = ctx->function_library()->graph_def_version();
|
options.graph_def_version = ctx->function_library()->graph_def_version();
|
||||||
|
@ -209,7 +209,9 @@ Status XlaCompilationCache::BuildExecutable(
|
|||||||
argument_layouts[i] = &result.xla_input_shapes[i];
|
argument_layouts[i] = &result.xla_input_shapes[i];
|
||||||
}
|
}
|
||||||
xla::ExecutableBuildOptions build_options;
|
xla::ExecutableBuildOptions build_options;
|
||||||
build_options.set_device_ordinal(client_->default_device_ordinal());
|
build_options.set_device_ordinal(options.device_ordinal != -1
|
||||||
|
? options.device_ordinal
|
||||||
|
: client_->default_device_ordinal());
|
||||||
build_options.set_result_layout(result.xla_output_shape);
|
build_options.set_result_layout(result.xla_output_shape);
|
||||||
build_options.set_device_allocator(options.device_allocator);
|
build_options.set_device_allocator(options.device_allocator);
|
||||||
|
|
||||||
|
@ -252,6 +252,12 @@ class XlaCompiler {
|
|||||||
// The default empty value is invalid.
|
// The default empty value is invalid.
|
||||||
DeviceType device_type = DeviceType("");
|
DeviceType device_type = DeviceType("");
|
||||||
|
|
||||||
|
// The device to use during compilation to execute instructions on, for
|
||||||
|
// example for auto-tuning.
|
||||||
|
// Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
|
||||||
|
// -1 indicates the default device should be used.
|
||||||
|
int device_ordinal = -1;
|
||||||
|
|
||||||
xla::Client* client = nullptr;
|
xla::Client* client = nullptr;
|
||||||
|
|
||||||
// Function library in which to find function definitions. Must be non-null.
|
// Function library in which to find function definitions. Must be non-null.
|
||||||
|
Loading…
Reference in New Issue
Block a user