[TF2XLA] [NFC] Simplify GenerateCompilerOptions and GetAllocator
PiperOrigin-RevId: 328975165 Change-Id: I3288abc39c04141178df98ec614ee247dd4740ec
This commit is contained in:
parent
0c92bd7381
commit
5d131e795c
@ -191,7 +191,9 @@ static Status CompileToLocalExecutable(
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
XlaCompiler::Options options = GenerateCompilerOptions(
|
||||
*cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
|
||||
*cache, *ctx->function_library(), ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info, has_ref_vars, &tf_allocator_adapter);
|
||||
|
||||
std::map<int, Tensor> constant_args;
|
||||
for (int i : constants) {
|
||||
@ -248,8 +250,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_);
|
||||
int device_ordinal = stream ? stream->parent()->device_ordinal()
|
||||
: client->default_device_ordinal();
|
||||
XlaComputationLaunchContext launch_context(
|
||||
@ -472,8 +476,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
XlaExecutableClosureStore::Global()->Consume(key);
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_);
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
int device_ordinal = stream ? stream->parent()->device_ordinal()
|
||||
|
@ -49,8 +49,10 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_);
|
||||
XlaComputationLaunchContext launch_context(
|
||||
client, allocator, client->default_device_ordinal(),
|
||||
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
|
||||
@ -162,9 +164,11 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
}));
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
XlaCompiler::Options options =
|
||||
GenerateCompilerOptions(**cache, ctx, platform_info_,
|
||||
/*has_ref_vars=*/true, &tf_allocator_adapter);
|
||||
XlaCompiler::Options options = GenerateCompilerOptions(
|
||||
**cache, *ctx->function_library(), ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_,
|
||||
/*has_ref_vars=*/true, &tf_allocator_adapter);
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = true;
|
||||
|
@ -110,41 +110,40 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
|
||||
|
||||
se::DeviceMemoryAllocator* GetAllocator(
|
||||
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
|
||||
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
|
||||
DeviceBase* device, se::Stream* stream,
|
||||
const XlaPlatformInfo& platform_info) {
|
||||
if (platform_info.custom_allocator()) {
|
||||
return platform_info.custom_allocator();
|
||||
}
|
||||
if (!ctx->op_device_context()) {
|
||||
if (!stream) {
|
||||
// Stream is not set for the host platform.
|
||||
se::Platform* platform =
|
||||
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
|
||||
.ValueOrDie();
|
||||
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
|
||||
tf_allocator_adapter->emplace(device->GetAllocator({}), platform);
|
||||
return &tf_allocator_adapter->value();
|
||||
}
|
||||
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
|
||||
ctx->op_device_context()->stream());
|
||||
tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
|
||||
return &tf_allocator_adapter->value();
|
||||
}
|
||||
|
||||
XlaCompiler::Options GenerateCompilerOptions(
|
||||
const XlaCompilationCache& cache, OpKernelContext* ctx,
|
||||
const XlaPlatformInfo& platform_info, bool has_ref_vars,
|
||||
const XlaCompilationCache& cache,
|
||||
const FunctionLibraryRuntime& function_library, DeviceBase* device,
|
||||
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
|
||||
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
|
||||
CHECK(ctx->function_library());
|
||||
XlaCompiler::Options options;
|
||||
options.client = static_cast<xla::LocalClient*>(cache.client());
|
||||
if (ctx->op_device_context() != nullptr) {
|
||||
options.device_ordinal =
|
||||
ctx->op_device_context()->stream()->parent()->device_ordinal();
|
||||
if (stream != nullptr) {
|
||||
options.device_ordinal = stream->parent()->device_ordinal();
|
||||
}
|
||||
options.device_type = cache.device_type();
|
||||
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
|
||||
options.graph_def_version = ctx->function_library()->graph_def_version();
|
||||
options.flib_def = function_library.GetFunctionLibraryDefinition();
|
||||
options.graph_def_version = function_library.graph_def_version();
|
||||
options.allow_cpu_custom_calls =
|
||||
(platform_info.platform_id() == se::host::kHostPlatformId);
|
||||
options.device_allocator =
|
||||
GetAllocator(tf_allocator_adapter, ctx, platform_info);
|
||||
GetAllocator(tf_allocator_adapter, device, stream, platform_info);
|
||||
if (platform_info.xla_device_metadata()) {
|
||||
options.shape_representation_fn =
|
||||
platform_info.xla_device_metadata()->shape_representation_fn();
|
||||
|
@ -92,15 +92,19 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
|
||||
//
|
||||
// This is necessary because for XLA devices the underlying TF allocator returns
|
||||
// dummy tensors.
|
||||
//
|
||||
// `stream` parameter is nullable when running on host.
|
||||
se::DeviceMemoryAllocator* GetAllocator(
|
||||
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
|
||||
OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
|
||||
DeviceBase* device, se::Stream* stream,
|
||||
const XlaPlatformInfo& platform_info);
|
||||
|
||||
// Returns created options for the XLA compiler, and writes the used allocator
|
||||
// into `tf_allocator_adapter`.
|
||||
XlaCompiler::Options GenerateCompilerOptions(
|
||||
const XlaCompilationCache& cache, OpKernelContext* ctx,
|
||||
const XlaPlatformInfo& platform_info, bool has_ref_vars,
|
||||
const XlaCompilationCache& cache,
|
||||
const FunctionLibraryRuntime& function_library, DeviceBase* device,
|
||||
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
|
||||
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user