Pass the device ordinal to use for execution to the XLA compiler for auto-tuning.

Previously, when compiling a graph for multiple devices concurrently, XLA
would use the default device for auto-tuning. With this patch, tf_cnn_benchmark
with model resnet50 finishes on 8 V100s at batch 128 and gets a speedup of ~20%
over a single V100; the next steps are to get it to run at batch 256 and to
scale well.

PiperOrigin-RevId: 206720140
This commit is contained in:
A. Unique TensorFlower 2018-07-31 01:17:48 -07:00 committed by TensorFlower Gardener
parent 78f58629ae
commit 3a1df26a25
3 changed files with 13 additions and 1 deletions

View File

@@ -153,6 +153,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
XlaCompiler::Options options;
options.client = client;
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
}
options.device_type = cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();

View File

@@ -209,7 +209,9 @@ Status XlaCompilationCache::BuildExecutable(
argument_layouts[i] = &result.xla_input_shapes[i];
}
xla::ExecutableBuildOptions build_options;
build_options.set_device_ordinal(client_->default_device_ordinal());
build_options.set_device_ordinal(options.device_ordinal != -1
? options.device_ordinal
: client_->default_device_ordinal());
build_options.set_result_layout(result.xla_output_shape);
build_options.set_device_allocator(options.device_allocator);

View File

@@ -252,6 +252,12 @@ class XlaCompiler {
// The default empty value is invalid.
DeviceType device_type = DeviceType("");
// The device to use during compilation to execute instructions on, for
// example for auto-tuning.
// Valid values are defined by `xla::Backend::device_ordinal_supported()`.
// -1 indicates the default device should be used.
int device_ordinal = -1;
xla::Client* client = nullptr;
// Function library in which to find function definitions. Must be non-null.