[TF2XLA] [NFC] Allow using XlaCompileOnDemandOp without XlaDeviceMetadata
PiperOrigin-RevId: 325910208
Change-Id: I24f6b14fa24c614b0994ee2efdd077e5ef2fe55e

parent 14f6926300
commit 53f2447911
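
At a glance, the refactor this diff performs: the XlaPlatformInfo plumbing moves out of kernels/xla_ops.cc into a reusable xla_platform_info library, and XlaCompileOnDemandOp stops requiring XlaDevice::GetMetadata to succeed. Below is a minimal sketch of the flow the new helpers enable; DriveCompilation is a hypothetical wrapper added here for illustration only, while the helper signatures are the ones introduced by this diff.

// Sketch only (not part of the commit): a hypothetical driver showing the
// helper surface added in tensorflow/compiler/jit/xla_platform_info.h.
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

// In the real op, platform_info is captured at construction time, as the
// updated XlaCompileOnDemandOp constructor does:
//   platform_info_(XlaPlatformInfoFromContext(ctx))
Status DriveCompilation(OpKernelContext* ctx,
                        const XlaPlatformInfo& platform_info) {
  // Works on plain CPU/GPU devices too: off an XlaDevice,
  // xla_device_metadata() is null and BuildXlaCompilationCache falls back to
  // looking up the platform by id and creating a LocalClient for it.
  XlaCompilationCache* cache = nullptr;
  TF_RETURN_IF_ERROR(BuildXlaCompilationCache(ctx, platform_info, &cache));
  core::ScopedUnref cache_ref(cache);  // Release our reference when done.

  // GenerateCompilerOptions centralizes what CompileToLocalExecutable used to
  // assemble inline; the allocator it picks is kept alive by the adapter.
  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  XlaCompiler::Options options = GenerateCompilerOptions(
      cache, ctx, platform_info, /*has_ref_vars=*/true, &tf_allocator_adapter);

  // ... compile with cache->CompileSingleOp(options, ...) and run, as in the
  // updated XlaCompileOnDemandOp::Compile/Run below.
  return Status::OK();
}

}  // namespace tensorflow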

tensorflow/compiler/jit/BUILD

@@ -195,6 +195,7 @@ XLA_DEVICE_DEPS = [
     "//tensorflow/core/kernels/data:optional_ops",
     "//tensorflow/core/kernels/data:prefetch_dataset_op",
     "//tensorflow/core/profiler/lib:traceme",
+    "//tensorflow/stream_executor:tf_allocator_adapter",
     "//tensorflow/stream_executor/platform",
 ]
@@ -205,16 +206,18 @@ cc_library(
         "xla_device.cc",
         "xla_device_context.cc",
         "xla_device_ops.cc",
+        "xla_platform_info.cc",
     ],
     hdrs = [
         "xla_compile_on_demand_op.h",
         "xla_device.h",
         "xla_device_context.h",
         "xla_device_ops.h",
+        "xla_platform_info.h",
     ],
     # Public visibility is needed for external TF/XLA backends.
     visibility = ["//visibility:public"],
-    deps = XLA_DEVICE_DEPS,
+    deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"],
 )

 cc_library(

tensorflow/compiler/jit/kernels/xla_ops.cc

@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/jit/xla_platform_info.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -63,38 +64,6 @@ namespace tensorflow {
 
 namespace {
 
-XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
-  DeviceType device_type = ctx->device_type();
-  se::Platform::Id platform_id = nullptr;
-  const XlaDevice::Metadata* xla_device_metadata = nullptr;
-  se::DeviceMemoryAllocator* custom_allocator = nullptr;
-
-  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
-    platform_id = se::host::kHostPlatformId;
-  } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
-    platform_id = ctx->device()
-                      ->tensorflow_gpu_device_info()
-                      ->stream->parent()
-                      ->platform()
-                      ->id();
-  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
-    // If we are on an XlaDevice, use the underlying XLA platform's allocator
-    // directly. We could use the StreamExecutor's allocator which may
-    // theoretically be more correct, but XLA returns a nice OOM message in a
-    // Status and StreamExecutor does not.
-    //
-    // Importantly we can't use ctx->device()->GetAllocator() as the allocator
-    // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
-    // allocator that returns XlaTensor objects. The XlaCompiler needs a real
-    // allocator to allocate real buffers.
-    platform_id = xla_device_metadata->platform()->id();
-    custom_allocator =
-        xla_device_metadata->client()->backend().memory_allocator();
-  }
-
-  return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
-                         custom_allocator);
-}
-
 // A closure describing how to run a compiled version of a TensorFlow function.
 //
@@ -178,31 +147,6 @@ class XlaExecutableClosureStore {
   TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
 };
 
-// Return allocator from platform info if non-null, or populate and return a
-// pointer to the allocator adapter with allocator from context.
-//
-// This is necessary because for XLA devices the underlying TF allocator returns
-// dummy tensors.
-se::DeviceMemoryAllocator* GetAllocator(
-    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
-    OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
-  if (platform_info.custom_allocator()) {
-    return platform_info.custom_allocator();
-  }
-  if (!ctx->op_device_context()) {
-    // Stream is not set for the host platform.
-    se::Platform* platform =
-        se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
-            .ValueOrDie();
-    tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
-    return &tf_allocator_adapter->value();
-  }
-  // platform_info.
-  tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
-                                ctx->op_device_context()->stream());
-  return &tf_allocator_adapter->value();
-}
-
 }  // namespace
 
 XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
@@ -214,65 +158,9 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
       constants_(constants),
       resources_(resources),
       function_(function),
-      platform_info_(PlatformInfoFromContext(ctx)),
+      platform_info_(XlaPlatformInfoFromContext(ctx)),
       has_ref_vars_(has_ref_vars) {}
 
-static Status BuildCompilationCache(OpKernelContext* ctx,
-                                    const XlaPlatformInfo& platform_info,
-                                    XlaCompilationCache** cache) {
-  if (platform_info.xla_device_metadata()) {
-    *cache = new XlaCompilationCache(
-        platform_info.xla_device_metadata()->client(),
-        platform_info.xla_device_metadata()->jit_device_type());
-    return Status::OK();
-  }
-
-  auto platform =
-      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
-  if (!platform.ok()) {
-    return platform.status();
-  }
-
-  xla::StatusOr<xla::Compiler*> compiler_for_platform =
-      xla::Compiler::GetForPlatform(platform.ValueOrDie());
-  if (!compiler_for_platform.ok()) {
-    // In some rare cases (usually in unit tests with very small clusters) we
-    // may end up transforming an XLA cluster with at least one GPU operation
-    // (which would normally force the cluster to be compiled using XLA:GPU)
-    // into an XLA cluster with no GPU operations (i.e. containing only CPU
-    // operations). Such a cluster can fail compilation (in way that
-    // MarkForCompilation could not have detected) if the CPU JIT is not linked
-    // in.
-    //
-    // So bail out of _XlaCompile in this case, and let the executor handle the
-    // situation for us.
-    const Status& status = compiler_for_platform.status();
-    if (status.code() == error::NOT_FOUND) {
-      return errors::Unimplemented("Could not find compiler for platform ",
-                                   platform.ValueOrDie()->Name(), ": ",
-                                   status.ToString());
-    }
-  }
-
-  xla::LocalClientOptions client_options;
-  client_options.set_platform(platform.ValueOrDie());
-  client_options.set_intra_op_parallelism_threads(
-      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
-  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
-  if (!client.ok()) {
-    return client.status();
-  }
-  const XlaOpRegistry::DeviceRegistration* registration;
-  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
-                                           &registration)) {
-    return errors::InvalidArgument("No JIT device registered for ",
-                                   platform_info.device_type().type());
-  }
-  *cache = new XlaCompilationCache(
-      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
-  return Status::OK();
-}
-
 static Status CompileToLocalExecutable(
     OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
     const XlaPlatformInfo& platform_info,
@@ -292,7 +180,7 @@ static Status CompileToLocalExecutable(
   TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
       rm->default_container(), "xla_cache", &cache,
       [&](XlaCompilationCache** cache) {
-        return BuildCompilationCache(ctx, platform_info, cache);
+        return BuildXlaCompilationCache(ctx, platform_info, cache);
       }));
   // Hold the reference to the JIT during evaluation. (We could probably
   // free it sooner because the ResourceMgr will retain a reference, but
@@ -302,32 +190,14 @@ static Status CompileToLocalExecutable(
   *client = static_cast<xla::LocalClient*>(cache->client());
 
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
-  XlaCompiler::Options options;
-  options.client = *client;
-  if (ctx->op_device_context() != nullptr) {
-    options.device_ordinal =
-        ctx->op_device_context()->stream()->parent()->device_ordinal();
-  }
-  options.device_type = cache->device_type();
-  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
-  options.graph_def_version = ctx->function_library()->graph_def_version();
-  options.allow_cpu_custom_calls =
-      (platform_info.platform_id() == se::host::kHostPlatformId);
-  options.device_allocator =
-      GetAllocator(&tf_allocator_adapter, ctx, platform_info);
-  if (platform_info.xla_device_metadata()) {
-    options.shape_representation_fn =
-        platform_info.xla_device_metadata()->shape_representation_fn();
-  }
-  // If reference variables are not present in the graph, we can safely alias
-  // passthrough parameters without performing a copy.
-  options.alias_passthrough_params =
-      !has_ref_vars && !platform_info.is_on_xla_device();
+  XlaCompiler::Options options = GenerateCompilerOptions(
+      cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
 
   std::map<int, Tensor> constant_args;
   for (int i : constants) {
     constant_args.insert({i, ctx->input(i)});
   }
 
   XlaCompiler::CompileOptions compile_options;
   compile_options.is_entry_computation = true;
   // Optimization: where possible, have the computation return a naked array
@@ -503,7 +373,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
       constants_(ConstantsVector(ctx)),
       resources_(ResourcesVector(ctx)),
       function_(FunctionAttr(ctx)),
-      platform_info_(PlatformInfoFromContext(ctx)),
+      platform_info_(XlaPlatformInfoFromContext(ctx)),
       must_compile_(MustCompileAttr(ctx)),
      has_ref_vars_(HasRefVars(ctx)) {}
@@ -591,7 +461,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
 }
 
 XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
-    : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
+    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
 
 void XlaRunOp::Compute(OpKernelContext* ctx) {
   VLOG(3) << "XlaRunOp " << def().name();

tensorflow/compiler/jit/kernels/xla_ops.h

@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_compilation_cache.h"
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/compiler/jit/xla_platform_info.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -31,61 +32,6 @@ limitations under the License.
 
 namespace tensorflow {
 
-// Holds some information about the platform on which an
-// XlaLaunch/_XlaCompile/_XlaRun op must run on.
-class XlaPlatformInfo {
- public:
-  XlaPlatformInfo() : device_type_("") {}
-  XlaPlatformInfo(XlaPlatformInfo&&) = default;
-  explicit XlaPlatformInfo(const DeviceType device_type,
-                           se::Platform::Id platform_id,
-                           const XlaDevice::Metadata* xla_device_metadata,
-                           se::DeviceMemoryAllocator* device_allocator)
-      : device_type_(device_type),
-        platform_id_(platform_id),
-        xla_device_metadata_(xla_device_metadata),
-        device_allocator_(device_allocator) {}
-
-  XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
-
-  bool UseMultipleStreams() const {
-    return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
-  }
-
-  // Non-null only when run on an XLA device.
-  se::DeviceMemoryAllocator* custom_allocator() const {
-    return device_allocator_;
-  }
-
-  DeviceType device_type() const { return device_type_; }
-
-  // This is equal to xla_device_metadata()->platform()->id() if
-  // xla_device_metadata() is not nullptr.
-  se::Platform::Id platform_id() const { return platform_id_; }
-
-  // This may be null if the op this XlaPlatformInfo is for was not placed on an
-  // XLA device.
-  const XlaDevice::Metadata* xla_device_metadata() const {
-    return xla_device_metadata_;
-  }
-  bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
-
- private:
-  DeviceType device_type_;
-  se::Platform::Id platform_id_;
-
-  // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
-  // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
-  // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
-  const XlaDevice::Metadata* xla_device_metadata_;
-
-  // If the op associated with this XlaPlatformInfo is placed on an XLA device
-  // then device_allocator_ is the xla::Backend's memory allocator. If the op
-  // is placed on a regular CPU or GPU device then device_allocator_ is null.
-  se::DeviceMemoryAllocator* device_allocator_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
-};
-
 // XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
 // The only difference is that it does not require arguments to follow

tensorflow/compiler/jit/xla_compile_on_demand_op.cc

@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/compiler/jit/xla_platform_info.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -41,18 +42,19 @@ static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
 }
 
 Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
-                                 const XlaDevice::Metadata& metadata,
+                                 XlaCompilationCache* cache,
                                  const XlaCompiler::CompilationResult* result,
                                  xla::LocalExecutable* executable,
                                  const ResourceVarsSnapshot& variable_args) {
-  xla::LocalClient* client = metadata.client();
+  xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
 
   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
       client, client->backend().memory_allocator(),
       client->default_device_ordinal(),
-      /*allocate_xla_tensors=*/true,
-      /*use_multiple_streams=*/metadata.UseMultipleStreams());
+      /*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
+      platform_info_.xla_device_metadata()
+          ? platform_info_.xla_device_metadata()->UseMultipleStreams()
+          : false);
 
   std::map<int, const Tensor*> snapshot_ptrs;
   for (auto& p : variable_args) {
@@ -70,7 +72,6 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
 
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
-  TF_RET_CHECK(stream);
 
   VLOG(2) << "Executing computation: " << name();
   xla::ExecutableRunOptions run_options;
@@ -116,9 +117,9 @@ Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
 }
 
 Status XlaCompileOnDemandOp::Compile(
-    OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
-    const XlaCompiler::CompilationResult** result,
-    ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
+    OpKernelContext* ctx, const XlaCompiler::CompilationResult** result,
+    XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args,
+    xla::LocalExecutable** executable) {
   std::map<int, Tensor> constant_arguments;
   for (int64 i = 0; i < ctx->num_inputs(); ++i) {
     const Tensor& device_tensor = ctx->input(i);
@@ -168,24 +169,16 @@ Status XlaCompileOnDemandOp::Compile(
   ResourceMgr* rm = ctx->resource_manager();
   CHECK(rm);
 
-  XlaCompilationCache* cache;
   TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
-      rm->default_container(), "xla_cache", &cache,
-      [&](XlaCompilationCache** cache) {
-        *cache = new XlaCompilationCache(metadata.client(),
-                                         metadata.jit_device_type());
-        return Status::OK();
+      rm->default_container(), "xla_cache", cache,
+      [&](XlaCompilationCache** write_into_cache) {
+        return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache);
       }));
-  // Hold the reference to the JIT during evaluation. (We could probably
-  // free it sooner because the ResourceMgr will retain a reference, but
-  // this is more obviously correct.)
-  core::ScopedUnref cache_ref(cache);
 
-  XlaCompiler::Options options;
-  options.device_type = metadata.jit_device_type();
-  options.client = metadata.client();
-  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
-  options.shape_representation_fn = metadata.shape_representation_fn();
+  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
+  XlaCompiler::Options options =
+      GenerateCompilerOptions(*cache, ctx, platform_info_,
+                              /*has_ref_vars=*/true, &tf_allocator_adapter);
 
   XlaCompiler::CompileOptions compile_options;
   compile_options.is_entry_computation = true;
@@ -206,19 +199,23 @@ Status XlaCompileOnDemandOp::Compile(
         constant_arguments, variable_infos, ctx, &args));
   }
 
-  return cache->CompileSingleOp(options, args, ctx, compile_options, result,
-                                executable);
+  return (*cache)->CompileSingleOp(options, args, ctx, compile_options, result,
+                                   executable);
 }
 
 void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
   const XlaCompiler::CompilationResult* result;
   xla::LocalExecutable* executable;
-  const XlaDevice::Metadata* metadata;
-  OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
   ResourceVarsSnapshot variable_args;
+  XlaCompilationCache* cache;
   OP_REQUIRES_OK(ctx,
-                 Compile(ctx, *metadata, &result, &variable_args, &executable));
-  OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
+                 Compile(ctx, &result, &cache, &variable_args, &executable));
+
+  // Hold the reference to the JIT during evaluation. (We could probably
+  // free it sooner because the ResourceMgr will retain a reference, but
+  // this is more obviously correct.)
+  core::ScopedUnref cache_ref(cache);
+  OP_REQUIRES_OK(ctx, Run(ctx, cache, result, executable, variable_args));
 }
 
 }  // namespace tensorflow

tensorflow/compiler/jit/xla_compile_on_demand_op.h

@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/compiler/jit/xla_platform_info.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/framework/function.h"
@@ -35,7 +36,8 @@ namespace tensorflow {
 // vanilla TensorFlow op as long as the bridge supports it.
 class XlaCompileOnDemandOp : public OpKernel {
  public:
-  explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
   void Compute(OpKernelContext* ctx) override;
 
  private:
@@ -46,14 +48,18 @@ class XlaCompileOnDemandOp : public OpKernel {
   Status MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx,
                                 FunctionLibraryRuntime* flib_runtime,
                                 bool* result);
-  Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
+  Status Compile(OpKernelContext* ctx,
                  const XlaCompiler::CompilationResult** result,
+                 XlaCompilationCache** cache,
                  ResourceVarsSnapshot* variable_args,
                  xla::LocalExecutable** executable);
-  Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
+
+  Status Run(OpKernelContext* ctx, XlaCompilationCache* cache,
              const XlaCompiler::CompilationResult* result,
             xla::LocalExecutable* executable,
             const ResourceVarsSnapshot& variable_args);
+
+  const XlaPlatformInfo platform_info_;
 };
 
 }  // namespace tensorflow

tensorflow/compiler/jit/xla_platform_info.cc (new file, 158 lines)

@@ -0,0 +1,158 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_platform_info.h"
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+
+namespace tensorflow {
+
+Status BuildXlaCompilationCache(OpKernelContext* ctx,
+                                const XlaPlatformInfo& platform_info,
+                                XlaCompilationCache** cache) {
+  if (platform_info.xla_device_metadata()) {
+    *cache = new XlaCompilationCache(
+        platform_info.xla_device_metadata()->client(),
+        platform_info.xla_device_metadata()->jit_device_type());
+    return Status::OK();
+  }
+
+  auto platform =
+      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
+  if (!platform.ok()) {
+    return platform.status();
+  }
+
+  xla::StatusOr<xla::Compiler*> compiler_for_platform =
+      xla::Compiler::GetForPlatform(platform.ValueOrDie());
+  if (!compiler_for_platform.ok()) {
+    // In some rare cases (usually in unit tests with very small clusters) we
+    // may end up transforming an XLA cluster with at least one GPU operation
+    // (which would normally force the cluster to be compiled using XLA:GPU)
+    // into an XLA cluster with no GPU operations (i.e. containing only CPU
+    // operations). Such a cluster can fail compilation (in a way that
+    // MarkForCompilation could not have detected) if the CPU JIT is not linked
+    // in.
+    //
+    // So bail out of _XlaCompile in this case, and let the executor handle the
+    // situation for us.
+    const Status& status = compiler_for_platform.status();
+    if (status.code() == error::NOT_FOUND) {
+      return errors::Unimplemented("Could not find compiler for platform ",
+                                   platform.ValueOrDie()->Name(), ": ",
+                                   status.ToString());
+    }
+  }
+
+  xla::LocalClientOptions client_options;
+  client_options.set_platform(platform.ValueOrDie());
+  client_options.set_intra_op_parallelism_threads(
+      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
+  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
+  if (!client.ok()) {
+    return client.status();
+  }
+  const XlaOpRegistry::DeviceRegistration* registration;
+  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
+                                           &registration)) {
+    return errors::InvalidArgument("No JIT device registered for ",
+                                   platform_info.device_type().type());
+  }
+  *cache = new XlaCompilationCache(
+      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
+  return Status::OK();
+}
+
+XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
+  DeviceType device_type = ctx->device_type();
+  se::Platform::Id platform_id = nullptr;
+  const XlaDevice::Metadata* xla_device_metadata = nullptr;
+  se::DeviceMemoryAllocator* custom_allocator = nullptr;
+
+  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+    platform_id = se::host::kHostPlatformId;
+  } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
+    platform_id = ctx->device()
+                      ->tensorflow_gpu_device_info()
+                      ->stream->parent()
+                      ->platform()
+                      ->id();
+  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+    // If we are on an XlaDevice, use the underlying XLA platform's allocator
+    // directly. We could use the StreamExecutor's allocator which may
+    // theoretically be more correct, but XLA returns a nice OOM message in a
+    // Status and StreamExecutor does not.
+    //
+    // Importantly we can't use ctx->device()->GetAllocator() as the allocator
+    // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
+    // allocator that returns XlaTensor objects. The XlaCompiler needs a real
+    // allocator to allocate real buffers.
+    platform_id = xla_device_metadata->platform()->id();
+    custom_allocator =
+        xla_device_metadata->client()->backend().memory_allocator();
+  }
+
+  return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
+                         custom_allocator);
+}
+
+se::DeviceMemoryAllocator* GetAllocator(
+    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
+    OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
+  if (platform_info.custom_allocator()) {
+    return platform_info.custom_allocator();
+  }
+  if (!ctx->op_device_context()) {
+    // Stream is not set for the host platform.
+    se::Platform* platform =
+        se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
+            .ValueOrDie();
+    tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
+    return &tf_allocator_adapter->value();
+  }
+  tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
+                                ctx->op_device_context()->stream());
+  return &tf_allocator_adapter->value();
+}
+
+XlaCompiler::Options GenerateCompilerOptions(
+    XlaCompilationCache* cache, OpKernelContext* ctx,
+    const XlaPlatformInfo& platform_info, bool has_ref_vars,
+    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
+  XlaCompiler::Options options;
+  options.client = static_cast<xla::LocalClient*>(cache->client());
+  if (ctx->op_device_context() != nullptr) {
+    options.device_ordinal =
+        ctx->op_device_context()->stream()->parent()->device_ordinal();
+  }
+  options.device_type = cache->device_type();
+  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+  options.graph_def_version = ctx->function_library()->graph_def_version();
+  options.allow_cpu_custom_calls =
+      (platform_info.platform_id() == se::host::kHostPlatformId);
+  options.device_allocator =
+      GetAllocator(tf_allocator_adapter, ctx, platform_info);
+  if (platform_info.xla_device_metadata()) {
+    options.shape_representation_fn =
+        platform_info.xla_device_metadata()->shape_representation_fn();
+  }
+  // If reference variables are not present in the graph, we can safely alias
+  // passthrough parameters without performing a copy.
+  options.alias_passthrough_params =
+      !has_ref_vars && !platform_info.is_on_xla_device();
+  return options;
+}
+
+}  // namespace tensorflow

tensorflow/compiler/jit/xla_platform_info.h (new file, 108 lines)

@@ -0,0 +1,108 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/stream_executor/tf_allocator_adapter.h"
+
+namespace tensorflow {
+
+// Holds some information about the platform on which an
+// XlaLaunch/_XlaCompile/_XlaRun op must run. Provides a common layer of
+// abstraction for normal and XLA devices.
+class XlaPlatformInfo {
+ public:
+  XlaPlatformInfo() : device_type_("") {}
+  XlaPlatformInfo(XlaPlatformInfo&&) = default;
+  explicit XlaPlatformInfo(const DeviceType device_type,
+                           se::Platform::Id platform_id,
+                           const XlaDevice::Metadata* xla_device_metadata,
+                           se::DeviceMemoryAllocator* device_allocator)
+      : device_type_(device_type),
+        platform_id_(platform_id),
+        xla_device_metadata_(xla_device_metadata),
+        device_allocator_(device_allocator) {}
+
+  XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
+
+  bool UseMultipleStreams() const {
+    return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
+  }
+
+  // Non-null only when run on an XLA device.
+  se::DeviceMemoryAllocator* custom_allocator() const {
+    return device_allocator_;
+  }
+
+  DeviceType device_type() const { return device_type_; }
+
+  // This is equal to xla_device_metadata()->platform()->id() if
+  // xla_device_metadata() is not nullptr.
+  se::Platform::Id platform_id() const { return platform_id_; }
+
+  // This may be null if the op this XlaPlatformInfo is for was not placed on
+  // an XLA device.
+  const XlaDevice::Metadata* xla_device_metadata() const {
+    return xla_device_metadata_;
+  }
+  bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
+
+ private:
+  DeviceType device_type_;
+  se::Platform::Id platform_id_;
+
+  // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
+  // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before
+  // the XlaLaunch/_XlaCompile/_XlaRun OpKernel.
+  const XlaDevice::Metadata* xla_device_metadata_;
+
+  // If the op associated with this XlaPlatformInfo is placed on an XLA device
+  // then device_allocator_ is the xla::Backend's memory allocator. If the op
+  // is placed on a regular CPU or GPU device then device_allocator_ is null.
+  se::DeviceMemoryAllocator* device_allocator_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
+};
+
+// Creates an XlaCompilationCache and returns it in `cache`.
+Status BuildXlaCompilationCache(OpKernelContext* ctx,
+                                const XlaPlatformInfo& platform_info,
+                                XlaCompilationCache** cache);
+
+// Returns information about the platform from the kernel construction context.
+XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
+
+// Returns the allocator from `platform_info` if non-null; otherwise populates
+// `tf_allocator_adapter` with an allocator from the context and returns a
+// pointer to it.
+//
+// This is necessary because for XLA devices the underlying TF allocator
+// returns dummy tensors.
+se::DeviceMemoryAllocator* GetAllocator(
+    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
+    OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
+
+// Returns options for the XLA compiler, and writes the allocator that was
+// used into `tf_allocator_adapter`.
+XlaCompiler::Options GenerateCompilerOptions(
+    XlaCompilationCache* cache, OpKernelContext* ctx,
+    const XlaPlatformInfo& platform_info, bool has_ref_vars,
+    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_