STT-tensorflow/tensorflow/compiler/jit/xla_platform_info.cc
George Karpenkov 5d131e795c [TF2XLA] [NFC] Simplify GenerateCompilerOptions and GetAllocator
PiperOrigin-RevId: 328975165
Change-Id: I3288abc39c04141178df98ec614ee247dd4740ec
2020-08-28 11:34:15 -07:00

159 lines
6.7 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/xla/client/client_library.h"
namespace tensorflow {
// Creates a new XlaCompilationCache for the device described by
// `platform_info` and stores it in `*cache`.
//
// If `platform_info` carries XlaDevice metadata, the cache reuses the XLA
// client and JIT device type recorded in that metadata. Otherwise:
// - a local XLA client is looked up or created for
//   `platform_info.platform_id()`, with intra-op parallelism taken from
//   `device`'s CPU worker threads;
// - the JIT device type is the compilation device registered in
//   XlaOpRegistry for `platform_info.device_type()`.
//
// Returns:
// - Unimplemented if the compiler lookup for the platform reports NOT_FOUND;
// - InvalidArgument if no JIT device is registered for the device type;
// - otherwise any error from platform, compiler or client lookup.
// `*cache` is only written on success.
Status BuildXlaCompilationCache(DeviceBase* device,
                                const XlaPlatformInfo& platform_info,
                                XlaCompilationCache** cache) {
  if (platform_info.xla_device_metadata()) {
    *cache = new XlaCompilationCache(
        platform_info.xla_device_metadata()->client(),
        platform_info.xla_device_metadata()->jit_device_type());
    return Status::OK();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
  if (!platform.ok()) {
    return platform.status();
  }

  xla::StatusOr<xla::Compiler*> compiler_for_platform =
      xla::Compiler::GetForPlatform(platform.ValueOrDie());
  if (!compiler_for_platform.ok()) {
    // In some rare cases (usually in unit tests with very small clusters) we
    // may end up transforming an XLA cluster with at least one GPU operation
    // (which would normally force the cluster to be compiled using XLA:GPU)
    // into an XLA cluster with no GPU operations (i.e. containing only CPU
    // operations). Such a cluster can fail compilation (in a way that
    // MarkForCompilation could not have detected) if the CPU JIT is not linked
    // in.
    //
    // So bail out of _XlaCompile in this case, and let the executor handle the
    // situation for us.
    const Status& status = compiler_for_platform.status();
    if (status.code() == error::NOT_FOUND) {
      return errors::Unimplemented("Could not find compiler for platform ",
                                   platform.ValueOrDie()->Name(), ": ",
                                   status.ToString());
    }
    // Any other compiler-lookup failure used to be silently ignored, and
    // execution continued on to client creation. Report the original error
    // instead.
    return status;
  }

  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      device->tensorflow_cpu_worker_threads()->num_threads);
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }

  // Initialized so the pointer is never read indeterminate, even though it is
  // only used when the registry lookup succeeds.
  const XlaOpRegistry::DeviceRegistration* registration = nullptr;
  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   platform_info.device_type().type());
  }
  *cache = new XlaCompilationCache(
      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
  return Status::OK();
}
// Describes `device_base` as an XlaPlatformInfo.
//
// - CPU devices: mapped to the XLA host platform.
// - GPU devices: mapped to the platform of the StreamExecutor that owns the
//   device's stream.
// - XlaDevices: mapped to the platform in their XlaDevice metadata. The
//   metadata and the underlying XLA client's memory allocator are recorded
//   as well.
// - Any other device: null platform id and no custom allocator.
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
  Device* const device = static_cast<Device*>(device_base);
  const DeviceType device_type(device->device_type());

  if (device->device_type() == DEVICE_CPU) {
    // No XlaDevice metadata and no custom allocator for host devices.
    return XlaPlatformInfo(device_type, se::host::kHostPlatformId, nullptr,
                           nullptr);
  }

  if (device->device_type() == DEVICE_GPU) {
    auto* const executor =
        device->tensorflow_gpu_device_info()->stream->parent();
    return XlaPlatformInfo(device_type, executor->platform()->id(), nullptr,
                           nullptr);
  }

  const XlaDevice::Metadata* metadata = nullptr;
  if (!XlaDevice::GetMetadataFromDevice(device, &metadata).ok()) {
    // Not a CPU, GPU or XLA device: the platform id stays unset and there is
    // no custom allocator.
    return XlaPlatformInfo(device_type, nullptr, metadata, nullptr);
  }

  // On an XlaDevice, give the compiler the underlying XLA client's allocator
  // directly. The StreamExecutor's allocator might in theory be more correct,
  // but XLA reports OOM as a descriptive Status and StreamExecutor does not.
  //
  // The device's own GetAllocator() is not usable here: on an XlaDevice it is
  // a dummy allocator that returns XlaTensor objects, while the XlaCompiler
  // needs a real allocator to allocate real buffers.
  const se::Platform::Id platform_id = metadata->platform()->id();
  se::DeviceMemoryAllocator* const allocator =
      metadata->client()->backend().memory_allocator();
  return XlaPlatformInfo(device_type, platform_id, metadata, allocator);
}
// Returns the device memory allocator XLA should use for `device`.
//
// If `platform_info` carries a custom allocator, that allocator is returned
// unchanged. Otherwise the device's TF allocator is wrapped in an
// se::TfAllocatorAdapter that is emplaced into `*tf_allocator_adapter`, and a
// pointer to the contained adapter is returned. That pointer is only valid
// while `*tf_allocator_adapter` keeps holding this value.
//
// The adapter is built around `stream` when one is given. Without a stream
// (the host platform), it is built around the platform looked up from
// `platform_info.platform_id()`; ValueOrDie() aborts if that lookup fails.
se::DeviceMemoryAllocator* GetAllocator(
    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
    DeviceBase* device, se::Stream* stream,
    const XlaPlatformInfo& platform_info) {
  if (auto* custom = platform_info.custom_allocator()) {
    return custom;
  }

  if (stream == nullptr) {
    // Stream is not set for the host platform, so key the adapter on the
    // platform itself.
    se::Platform* const host_platform =
        se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
            .ValueOrDie();
    tf_allocator_adapter->emplace(device->GetAllocator({}), host_platform);
  } else {
    tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
  }
  return &tf_allocator_adapter->value();
}
// Assembles XlaCompiler::Options for compiling on the device described by
// `platform_info`, using the client and device type held by `cache` and the
// function library from `function_library`.
//
// `stream` may be null. In that case device_ordinal keeps its default value.
// The device allocator comes from GetAllocator(); when that helper builds an
// adapter, the adapter lives in `*tf_allocator_adapter`, so that optional must
// outlive any use of the returned options' allocator.
XlaCompiler::Options GenerateCompilerOptions(
    const XlaCompilationCache& cache,
    const FunctionLibraryRuntime& function_library, DeviceBase* device,
    se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
  XlaCompiler::Options options;

  // Compilation target, taken from the cache.
  options.client = static_cast<xla::LocalClient*>(cache.client());
  options.device_type = cache.device_type();
  if (stream) {
    options.device_ordinal = stream->parent()->device_ordinal();
  }

  // Function definitions the compiled graph may reference.
  options.flib_def = function_library.GetFunctionLibraryDefinition();
  options.graph_def_version = function_library.graph_def_version();

  // CPU custom calls are permitted only on the XLA host platform.
  const bool on_host =
      platform_info.platform_id() == se::host::kHostPlatformId;
  options.allow_cpu_custom_calls = on_host;

  options.device_allocator =
      GetAllocator(tf_allocator_adapter, device, stream, platform_info);

  if (const auto* metadata = platform_info.xla_device_metadata()) {
    options.shape_representation_fn = metadata->shape_representation_fn();
  }

  // Passthrough parameters are aliased (no copy) only when the graph contains
  // no reference variables and the computation is not on an XlaDevice.
  options.alias_passthrough_params =
      !(has_ref_vars || platform_info.is_on_xla_device());
  return options;
}
} // namespace tensorflow