[TF2XLA] [NFC] Allow using XlaCompileOnDemandOp without XlaDeviceMetadata

PiperOrigin-RevId: 325910208
Change-Id: I24f6b14fa24c614b0994ee2efdd077e5ef2fe55e
George Karpenkov 2020-08-10 16:17:30 -07:00 committed by TensorFlower Gardener
parent 14f6926300
commit 53f2447911
7 changed files with 315 additions and 227 deletions
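For orientation (editor's sketch, not part of the commit): the helpers introduced in xla_platform_info.h later in this diff let a kernel derive its platform information and compilation cache without touching XlaDevice::Metadata. The hypothetical ExampleJitOp below only illustrates how those pieces fit together; the real usage is XlaCompileOnDemandOp further down.

// Sketch only. ExampleJitOp is hypothetical; the helpers are the ones this
// commit adds in tensorflow/compiler/jit/xla_platform_info.h.
class ExampleJitOp : public OpKernel {
 public:
  explicit ExampleJitOp(OpKernelConstruction* ctx)
      : OpKernel(ctx),
        // Works on CPU/GPU as well as XLA devices; xla_device_metadata() is
        // simply null when the op is not placed on an XlaDevice.
        platform_info_(XlaPlatformInfoFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override {
    ResourceMgr* rm = ctx->resource_manager();
    XlaCompilationCache* cache = nullptr;
    OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
                            rm->default_container(), "xla_cache", &cache,
                            [&](XlaCompilationCache** c) {
                              // Falls back to a LocalClient looked up by
                              // platform id when there is no XLA device
                              // metadata.
                              return BuildXlaCompilationCache(
                                  ctx, platform_info_, c);
                            }));
    core::ScopedUnref cache_ref(cache);

    absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
    XlaCompiler::Options options = GenerateCompilerOptions(
        cache, ctx, platform_info_, /*has_ref_vars=*/true,
        &tf_allocator_adapter);
    // ... compile via cache->CompileSingleOp(...) and execute, as
    // XlaCompileOnDemandOp does below.
  }

 private:
  const XlaPlatformInfo platform_info_;
};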


@@ -195,6 +195,7 @@ XLA_DEVICE_DEPS = [
"//tensorflow/core/kernels/data:optional_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter",
"//tensorflow/stream_executor/platform",
]
@@ -205,16 +206,18 @@ cc_library(
"xla_device.cc",
"xla_device_context.cc",
"xla_device_ops.cc",
"xla_platform_info.cc",
],
hdrs = [
"xla_compile_on_demand_op.h",
"xla_device.h",
"xla_device_context.h",
"xla_device_ops.h",
"xla_platform_info.h",
],
# Public visibility is needed for external TF/XLA backends.
visibility = ["//visibility:public"],
deps = XLA_DEVICE_DEPS,
deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"],
)
cc_library(


@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -63,38 +64,6 @@ namespace tensorflow {
namespace {
XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId;
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
platform_id = ctx->device()
->tensorflow_gpu_device_info()
->stream->parent()
->platform()
->id();
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
// If we are on an XlaDevice, use the underlying XLA platform's allocator
// directly. We could use the StreamExecutor's allocator which may
// theoretically be more correct, but XLA returns a nice OOM message in a
// Status and StreamExecutor does not.
//
// Importantly we can't use ctx->device()->GetAllocator() as the allocator
// (which xla_allocator above uses) as on an XlaDevice, this is a dummy
// allocator that returns XlaTensor objects. The XlaCompiler needs a real
// allocator to allocate real buffers.
platform_id = xla_device_metadata->platform()->id();
custom_allocator =
xla_device_metadata->client()->backend().memory_allocator();
}
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
custom_allocator);
}
// A closure describing how to run a compiled version of a TensorFlow function.
//
@@ -178,31 +147,6 @@ class XlaExecutableClosureStore {
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
// Returns the allocator from platform_info if non-null; otherwise, populates
// the allocator adapter with the allocator from the context and returns a
// pointer to it.
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
if (!ctx->op_device_context()) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
return &tf_allocator_adapter->value();
}
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
ctx->op_device_context()->stream());
return &tf_allocator_adapter->value();
}
} // namespace
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
@@ -214,65 +158,9 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
constants_(constants),
resources_(resources),
function_(function),
platform_info_(PlatformInfoFromContext(ctx)),
platform_info_(XlaPlatformInfoFromContext(ctx)),
has_ref_vars_(has_ref_vars) {}
static Status BuildCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache) {
if (platform_info.xla_device_metadata()) {
*cache = new XlaCompilationCache(
platform_info.xla_device_metadata()->client(),
platform_info.xla_device_metadata()->jit_device_type());
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
if (!platform.ok()) {
return platform.status();
}
xla::StatusOr<xla::Compiler*> compiler_for_platform =
xla::Compiler::GetForPlatform(platform.ValueOrDie());
if (!compiler_for_platform.ok()) {
// In some rare cases (usually in unit tests with very small clusters) we
// may end up transforming an XLA cluster with at least one GPU operation
// (which would normally force the cluster to be compiled using XLA:GPU)
// into an XLA cluster with no GPU operations (i.e. containing only CPU
operations). Such a cluster can fail compilation (in a way that
// MarkForCompilation could not have detected) if the CPU JIT is not linked
// in.
//
// So bail out of _XlaCompile in this case, and let the executor handle the
// situation for us.
const Status& status = compiler_for_platform.status();
if (status.code() == error::NOT_FOUND) {
return errors::Unimplemented("Could not find compiler for platform ",
platform.ValueOrDie()->Name(), ": ",
status.ToString());
}
}
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
client_options.set_intra_op_parallelism_threads(
ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
}
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
&registration)) {
return errors::InvalidArgument("No JIT device registered for ",
platform_info.device_type().type());
}
*cache = new XlaCompilationCache(
client.ValueOrDie(), DeviceType(registration->compilation_device_name));
return Status::OK();
}
static Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info,
@@ -292,7 +180,7 @@ static Status CompileToLocalExecutable(
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
return BuildCompilationCache(ctx, platform_info, cache);
return BuildXlaCompilationCache(ctx, platform_info, cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
@@ -302,32 +190,14 @@ static Status CompileToLocalExecutable(
*client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options;
options.client = *client;
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
}
options.device_type = cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();
}
// If reference variables are not present in the graph, we can safely alias
// passthrough parameters without performing a copy.
options.alias_passthrough_params =
!has_ref_vars && !platform_info.is_on_xla_device();
XlaCompiler::Options options = GenerateCompilerOptions(
cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
std::map<int, Tensor> constant_args;
for (int i : constants) {
constant_args.insert({i, ctx->input(i)});
}
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
// Optimization: where possible, have the computation return a naked array
@@ -503,7 +373,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
constants_(ConstantsVector(ctx)),
resources_(ResourcesVector(ctx)),
function_(FunctionAttr(ctx)),
platform_info_(PlatformInfoFromContext(ctx)),
platform_info_(XlaPlatformInfoFromContext(ctx)),
must_compile_(MustCompileAttr(ctx)),
has_ref_vars_(HasRefVars(ctx)) {}
@@ -591,7 +461,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
}
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
void XlaRunOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaRunOp " << def().name();


@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -31,61 +32,6 @@ limitations under the License.
namespace tensorflow {
// Holds some information about the platform on which an
// XlaLaunch/_XlaCompile/_XlaRun op must run.
class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
device_allocator_(device_allocator) {}
XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
bool UseMultipleStreams() const {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
}
// Non-null only when run on an XLA device.
se::DeviceMemoryAllocator* custom_allocator() const {
return device_allocator_;
}
DeviceType device_type() const { return device_type_; }
// This is equal to xla_device_metadata()->platform()->id() if
// xla_device_metadata() is not nullptr.
se::Platform::Id platform_id() const { return platform_id_; }
// This may be null if the op this XlaPlatformInfo is for was not placed on an
// XLA device.
const XlaDevice::Metadata* xla_device_metadata() const {
return xla_device_metadata_;
}
bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
private:
DeviceType device_type_;
se::Platform::Id platform_id_;
// xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
// XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
// XlaLaunch/_XlaCompile/_XlaRun OpKernel.
const XlaDevice::Metadata* xla_device_metadata_;
// If the op associated with this XlaPlatformInfo is placed on an XLA device
// then device_allocator_ is the xla::Backend's memory allocator. If the op
// is placed on a regular CPU or GPU device then device_allocator_ is null.
se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
// The only difference is that it does not require arguments to follow


@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -41,18 +42,19 @@ static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
}
Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaDevice::Metadata& metadata,
XlaCompilationCache* cache,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args) {
xla::LocalClient* client = metadata.client();
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
client, client->backend().memory_allocator(),
client->default_device_ordinal(),
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
platform_info_.xla_device_metadata()
? platform_info_.xla_device_metadata()->UseMultipleStreams()
: false);
std::map<int, const Tensor*> snapshot_ptrs;
for (auto& p : variable_args) {
@@ -70,7 +72,6 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
TF_RET_CHECK(stream);
VLOG(2) << "Executing computation: " << name();
xla::ExecutableRunOptions run_options;
@@ -116,9 +117,9 @@ Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
}
Status XlaCompileOnDemandOp::Compile(
OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult** result,
ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
OpKernelContext* ctx, const XlaCompiler::CompilationResult** result,
XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable) {
std::map<int, Tensor> constant_arguments;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& device_tensor = ctx->input(i);
@@ -168,24 +169,16 @@ Status XlaCompileOnDemandOp::Compile(
ResourceMgr* rm = ctx->resource_manager();
CHECK(rm);
XlaCompilationCache* cache;
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
*cache = new XlaCompilationCache(metadata.client(),
metadata.jit_device_type());
return Status::OK();
rm->default_container(), "xla_cache", cache,
[&](XlaCompilationCache** write_into_cache) {
return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
XlaCompiler::Options options;
options.device_type = metadata.jit_device_type();
options.client = metadata.client();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.shape_representation_fn = metadata.shape_representation_fn();
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options =
GenerateCompilerOptions(*cache, ctx, platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
@@ -206,19 +199,23 @@ Status XlaCompileOnDemandOp::Compile(
constant_arguments, variable_infos, ctx, &args));
}
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
return (*cache)->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
}
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
const XlaCompiler::CompilationResult* result;
xla::LocalExecutable* executable;
const XlaDevice::Metadata* metadata;
OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
ResourceVarsSnapshot variable_args;
XlaCompilationCache* cache;
OP_REQUIRES_OK(ctx,
Compile(ctx, *metadata, &result, &variable_args, &executable));
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
Compile(ctx, &result, &cache, &variable_args, &executable));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
OP_REQUIRES_OK(ctx, Run(ctx, cache, result, executable, variable_args));
}
} // namespace tensorflow


@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/function.h"
@@ -35,7 +36,8 @@ namespace tensorflow {
// vanilla TensorFlow op as long as the bridge supports it.
class XlaCompileOnDemandOp : public OpKernel {
public:
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
void Compute(OpKernelContext* ctx) override;
private:
@@ -46,14 +48,18 @@ class XlaCompileOnDemandOp : public OpKernel {
Status MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx,
FunctionLibraryRuntime* flib_runtime,
bool* result);
Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
Status Compile(OpKernelContext* ctx,
const XlaCompiler::CompilationResult** result,
XlaCompilationCache** cache,
ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable);
Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
Status Run(OpKernelContext* ctx, XlaCompilationCache* cache,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args);
const XlaPlatformInfo platform_info_;
};
} // namespace tensorflow


@@ -0,0 +1,158 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/xla/client/client_library.h"
namespace tensorflow {
Status BuildXlaCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache) {
if (platform_info.xla_device_metadata()) {
*cache = new XlaCompilationCache(
platform_info.xla_device_metadata()->client(),
platform_info.xla_device_metadata()->jit_device_type());
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
if (!platform.ok()) {
return platform.status();
}
xla::StatusOr<xla::Compiler*> compiler_for_platform =
xla::Compiler::GetForPlatform(platform.ValueOrDie());
if (!compiler_for_platform.ok()) {
// In some rare cases (usually in unit tests with very small clusters) we
// may end up transforming an XLA cluster with at least one GPU operation
// (which would normally force the cluster to be compiled using XLA:GPU)
// into an XLA cluster with no GPU operations (i.e. containing only CPU
operations). Such a cluster can fail compilation (in a way that
// MarkForCompilation could not have detected) if the CPU JIT is not linked
// in.
//
// So bail out of _XlaCompile in this case, and let the executor handle the
// situation for us.
const Status& status = compiler_for_platform.status();
if (status.code() == error::NOT_FOUND) {
return errors::Unimplemented("Could not find compiler for platform ",
platform.ValueOrDie()->Name(), ": ",
status.ToString());
}
}
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
client_options.set_intra_op_parallelism_threads(
ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
}
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
&registration)) {
return errors::InvalidArgument("No JIT device registered for ",
platform_info.device_type().type());
}
*cache = new XlaCompilationCache(
client.ValueOrDie(), DeviceType(registration->compilation_device_name));
return Status::OK();
}
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId;
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
platform_id = ctx->device()
->tensorflow_gpu_device_info()
->stream->parent()
->platform()
->id();
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
// If we are on an XlaDevice, use the underlying XLA platform's allocator
// directly. We could use the StreamExecutor's allocator which may
// theoretically be more correct, but XLA returns a nice OOM message in a
// Status and StreamExecutor does not.
//
// Importantly we can't use ctx->device()->GetAllocator() as the allocator
// (which xla_allocator above uses) as on an XlaDevice, this is a dummy
// allocator that returns XlaTensor objects. The XlaCompiler needs a real
// allocator to allocate real buffers.
platform_id = xla_device_metadata->platform()->id();
custom_allocator =
xla_device_metadata->client()->backend().memory_allocator();
}
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
custom_allocator);
}
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
if (!ctx->op_device_context()) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
return &tf_allocator_adapter->value();
}
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
ctx->op_device_context()->stream());
return &tf_allocator_adapter->value();
}
XlaCompiler::Options GenerateCompilerOptions(
XlaCompilationCache* cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
XlaCompiler::Options options;
options.client = static_cast<xla::LocalClient*>(cache->client());
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
}
options.device_type = cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(tf_allocator_adapter, ctx, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();
}
// If reference variables are not present in the graph, we can safely alias
// passthrough parameters without performing a copy.
options.alias_passthrough_params =
!has_ref_vars && !platform_info.is_on_xla_device();
return options;
}
} // namespace tensorflow
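A note on the two helpers above (editor's observation, not text from the commit): GetAllocator may return a pointer into the caller-supplied tf_allocator_adapter, so that optional has to outlive every use of the returned allocator, including the XlaCompiler::Options that GenerateCompilerOptions fills from it. The call sites in this commit keep the optional on the stack for the whole compilation, roughly like this (assuming cache, ctx, and platform_info are already in scope):

absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options = GenerateCompilerOptions(
    cache, ctx, platform_info, /*has_ref_vars=*/true, &tf_allocator_adapter);
// options.device_allocator may point at *tf_allocator_adapter; keep the
// optional alive until compilation (and any other use of options) is done.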


@@ -0,0 +1,108 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
#define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
namespace tensorflow {
// Holds some information about the platform on which an
// XlaLaunch/_XlaCompile/_XlaRun op must run. Provides a common layer of
// abstraction for normal and XLA devices.
class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
device_allocator_(device_allocator) {}
XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
bool UseMultipleStreams() const {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
}
// Non-null only when run on an XLA device.
se::DeviceMemoryAllocator* custom_allocator() const {
return device_allocator_;
}
DeviceType device_type() const { return device_type_; }
// This is equal to xla_device_metadata()->platform()->id() if
// xla_device_metadata() is not nullptr.
se::Platform::Id platform_id() const { return platform_id_; }
// This may be null if the op this XlaPlatformInfo is for was not placed on an
// XLA device.
const XlaDevice::Metadata* xla_device_metadata() const {
return xla_device_metadata_;
}
bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
private:
DeviceType device_type_;
se::Platform::Id platform_id_;
// xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
// XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
// XlaLaunch/_XlaCompile/_XlaRun OpKernel.
const XlaDevice::Metadata* xla_device_metadata_;
// If the op associated with this XlaPlatformInfo is placed on an XLA device
// then device_allocator_ is the xla::Backend's memory allocator. If the op
// is placed on a regular CPU or GPU device then device_allocator_ is null.
se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
// Creates an XLA compilation cache and returns it in `cache`.
Status BuildXlaCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache);
// Returns information about the platform from kernel context.
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
// Returns the allocator from platform_info if non-null; otherwise, populates
// the allocator adapter with the allocator from the context and returns a
// pointer to it.
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
// Creates options for the XLA compiler and writes the allocator that was used
// into `tf_allocator_adapter`.
XlaCompiler::Options GenerateCompilerOptions(
XlaCompilationCache* cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_