diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 9e1fa7dedfa..79f1e47d98b 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -191,7 +191,9 @@ static Status CompileToLocalExecutable( absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter; XlaCompiler::Options options = GenerateCompilerOptions( - *cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter); + *cache, *ctx->function_library(), ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info, has_ref_vars, &tf_allocator_adapter); std::map<int, Tensor> constant_args; for (int i : constants) { @@ -248,8 +250,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "Executing XLA Computation..."; absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); int device_ordinal = stream ? stream->parent()->device_ordinal() : client->default_device_ordinal(); XlaComputationLaunchContext launch_context( @@ -472,8 +476,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { XlaExecutableClosureStore::Global()->Consume(key); absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; int device_ordinal = stream ?
stream->parent()->device_ordinal() diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index c99d3b6bb7c..ba20b532a11 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -49,8 +49,10 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client()); absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); XlaComputationLaunchContext launch_context( client, allocator, client->default_device_ordinal(), /*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr, @@ -162,9 +164,11 @@ Status XlaCompileOnDemandOp::Compile( })); absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter; - XlaCompiler::Options options = - GenerateCompilerOptions(**cache, ctx, platform_info_, - /*has_ref_vars=*/true, &tf_allocator_adapter); + XlaCompiler::Options options = GenerateCompilerOptions( + **cache, *ctx->function_library(), ctx->device(), + ctx->op_device_context() ?
ctx->op_device_context()->stream() : nullptr, + platform_info_, + /*has_ref_vars=*/true, &tf_allocator_adapter); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index aaebda88810..b38bf9282b1 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -110,41 +110,40 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) { se::DeviceMemoryAllocator* GetAllocator( absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter, - OpKernelContext* ctx, const XlaPlatformInfo& platform_info) { + DeviceBase* device, se::Stream* stream, + const XlaPlatformInfo& platform_info) { if (platform_info.custom_allocator()) { return platform_info.custom_allocator(); } - if (!ctx->op_device_context()) { + if (!stream) { // Stream is not set for the host platform. se::Platform* platform = se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) .ValueOrDie(); - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform); + tf_allocator_adapter->emplace(device->GetAllocator({}), platform); return &tf_allocator_adapter->value(); } - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), - ctx->op_device_context()->stream()); + tf_allocator_adapter->emplace(device->GetAllocator({}), stream); return &tf_allocator_adapter->value(); } XlaCompiler::Options GenerateCompilerOptions( - const XlaCompilationCache& cache, OpKernelContext* ctx, - const XlaPlatformInfo& platform_info, bool has_ref_vars, + const XlaCompilationCache& cache, + const FunctionLibraryRuntime& function_library, DeviceBase* device, + se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) { - CHECK(ctx->function_library()); XlaCompiler::Options options; options.client = static_cast<xla::LocalClient*>(cache.client()); - if (ctx->op_device_context() != nullptr) { -
options.device_ordinal = - ctx->op_device_context()->stream()->parent()->device_ordinal(); + if (stream != nullptr) { + options.device_ordinal = stream->parent()->device_ordinal(); } options.device_type = cache.device_type(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.graph_def_version = ctx->function_library()->graph_def_version(); + options.flib_def = function_library.GetFunctionLibraryDefinition(); + options.graph_def_version = function_library.graph_def_version(); options.allow_cpu_custom_calls = (platform_info.platform_id() == se::host::kHostPlatformId); options.device_allocator = - GetAllocator(tf_allocator_adapter, ctx, platform_info); + GetAllocator(tf_allocator_adapter, device, stream, platform_info); if (platform_info.xla_device_metadata()) { options.shape_representation_fn = platform_info.xla_device_metadata()->shape_representation_fn(); diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h index c4e0ec3f7da..bfb438cc398 100644 --- a/tensorflow/compiler/jit/xla_platform_info.h +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -92,15 +92,19 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); // // This is necessary because for XLA devices the underlying TF allocator returns // dummy tensors. +// +// `stream` parameter is nullable when running on host. se::DeviceMemoryAllocator* GetAllocator( absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter, - OpKernelContext* ctx, const XlaPlatformInfo& platform_info); + DeviceBase* device, se::Stream* stream, + const XlaPlatformInfo& platform_info); // Returns created options for the XLA compiler, and writes the used allocator // into `tf_allocator_adapter`.
XlaCompiler::Options GenerateCompilerOptions( - const XlaCompilationCache& cache, OpKernelContext* ctx, - const XlaPlatformInfo& platform_info, bool has_ref_vars, + const XlaCompilationCache& cache, + const FunctionLibraryRuntime& function_library, DeviceBase* device, + se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter); } // namespace tensorflow