[TF2XLA] [NFC] Simplify GenerateCompilerOptions and GetAllocator

PiperOrigin-RevId: 328975165
Change-Id: I3288abc39c04141178df98ec614ee247dd4740ec
Commit authored by George Karpenkov on 2020-08-28 11:22:43 -07:00, committed by TensorFlower Gardener.
parent 0c92bd7381
commit 5d131e795c
4 changed files with 40 additions and 27 deletions

View File

@ -191,7 +191,9 @@ static Status CompileToLocalExecutable(
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options = GenerateCompilerOptions(
*cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
*cache, *ctx->function_library(), ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info, has_ref_vars, &tf_allocator_adapter);
std::map<int, Tensor> constant_args;
for (int i : constants) {
@ -248,8 +250,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
int device_ordinal = stream ? stream->parent()->device_ordinal()
: client->default_device_ordinal();
XlaComputationLaunchContext launch_context(
@ -472,8 +476,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
XlaExecutableClosureStore::Global()->Consume(key);
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
int device_ordinal = stream ? stream->parent()->device_ordinal()

View File

@ -49,8 +49,10 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
XlaComputationLaunchContext launch_context(
client, allocator, client->default_device_ordinal(),
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
@ -162,9 +164,11 @@ Status XlaCompileOnDemandOp::Compile(
}));
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options =
GenerateCompilerOptions(**cache, ctx, platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
XlaCompiler::Options options = GenerateCompilerOptions(
**cache, *ctx->function_library(), ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;

View File

@ -110,41 +110,40 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
// Returns the memory allocator to use for XLA computations on `device`.
//
// Priority: a custom allocator registered in `platform_info` wins outright.
// Otherwise a TfAllocatorAdapter wrapping the device's TF allocator is
// constructed in-place inside `*tf_allocator_adapter` (which must outlive the
// returned pointer) — keyed on the platform when no stream is available, or on
// `stream` when one is.
//
// `stream` is nullable: it is not set when running on the host platform.
// NOTE(review): this reconstructs the post-commit version of the function; the
// scraped diff lost its +/- markers, interleaving old and new lines.
se::DeviceMemoryAllocator* GetAllocator(
    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
    DeviceBase* device, se::Stream* stream,
    const XlaPlatformInfo& platform_info) {
  if (platform_info.custom_allocator()) {
    return platform_info.custom_allocator();
  }
  if (!stream) {
    // Stream is not set for the host platform; fall back to a platform-keyed
    // adapter over the device's default TF allocator.
    se::Platform* platform =
        se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
            .ValueOrDie();
    tf_allocator_adapter->emplace(device->GetAllocator({}), platform);
    return &tf_allocator_adapter->value();
  }
  tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
  return &tf_allocator_adapter->value();
}
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
const XlaCompilationCache& cache,
const FunctionLibraryRuntime& function_library, DeviceBase* device,
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
CHECK(ctx->function_library());
XlaCompiler::Options options;
options.client = static_cast<xla::LocalClient*>(cache.client());
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
if (stream != nullptr) {
options.device_ordinal = stream->parent()->device_ordinal();
}
options.device_type = cache.device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.flib_def = function_library.GetFunctionLibraryDefinition();
options.graph_def_version = function_library.graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(tf_allocator_adapter, ctx, platform_info);
GetAllocator(tf_allocator_adapter, device, stream, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();

View File

@ -92,15 +92,19 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
//
// `stream` parameter is nullable when running on host.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
DeviceBase* device, se::Stream* stream,
const XlaPlatformInfo& platform_info);
// Returns created options for the XLA compiler, and writes the used allocator
// into `tf_allocator_adapter`.
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
const XlaCompilationCache& cache,
const FunctionLibraryRuntime& function_library, DeviceBase* device,
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
} // namespace tensorflow