[TF2XLA] [NFC] Simplify XlaPlatformInfoFromDevice

PiperOrigin-RevId: 328973906
Change-Id: I7e122bbdac7d196eee430c6dc0e1d76505f5126e
Author: George Karpenkov, 2020-08-28 11:16:33 -07:00 (committed by TensorFlower Gardener)
Parent: b73c9f1236
Commit: 454ebebeab
5 changed files with 20 additions and 16 deletions
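
In short, this change replaces XlaPlatformInfoFromContext(OpKernelConstruction*) with XlaPlatformInfoFromDevice(DeviceBase*): the platform info is now derived from the device alone, and kernel constructors pass ctx->device() instead of the whole construction context. A minimal before/after sketch of a call site (ExampleXlaKernel is a hypothetical name used only for illustration; the real call sites are in the hunks below):

// Before this commit, the kernel-construction context was required:
//   platform_info_(XlaPlatformInfoFromContext(ctx))
// After this commit, only the device is needed:
//   platform_info_(XlaPlatformInfoFromDevice(ctx->device()))
class ExampleXlaKernel : public OpKernel {
 public:
  explicit ExampleXlaKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx),
        platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

  void Compute(OpKernelContext* ctx) override { /* ... */ }

 private:
  // Cached at construction time, mirroring the kernels touched in this commit.
  const XlaPlatformInfo platform_info_;
};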

View File

@@ -158,7 +158,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
       constants_(constants),
       resources_(resources),
       function_(function),
-      platform_info_(XlaPlatformInfoFromContext(ctx)),
+      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
       has_ref_vars_(has_ref_vars) {}

 static Status CompileToLocalExecutable(
@@ -373,7 +373,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
       constants_(ConstantsVector(ctx)),
       resources_(ResourcesVector(ctx)),
       function_(FunctionAttr(ctx)),
-      platform_info_(XlaPlatformInfoFromContext(ctx)),
+      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
       must_compile_(MustCompileAttr(ctx)),
       has_ref_vars_(HasRefVars(ctx)) {}
@@ -461,7 +461,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
 }

 XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
-    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
+    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

 void XlaRunOp::Compute(OpKernelContext* ctx) {
   VLOG(3) << "XlaRunOp " << def().name();

View File

@@ -37,7 +37,8 @@ namespace tensorflow {
 class XlaCompileOnDemandOp : public OpKernel {
  public:
   explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
+      : OpKernel(ctx),
+        platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
   void Compute(OpKernelContext* ctx) override;

  private:

View File

@@ -94,6 +94,11 @@ class XlaDevice : public LocalDevice {
   static Status GetMetadata(OpKernelConstruction* ctx,
                             const Metadata** metadata);

+  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
+  // `device`.
+  static Status GetMetadataFromDevice(DeviceBase* device,
+                                      const XlaDevice::Metadata** metadata);
+
   struct Options {
     // The StreamExecutor platform. Not owned. Must be non-null.
     se::Platform* platform = nullptr;
@@ -196,8 +201,6 @@ class XlaDevice : public LocalDevice {
   xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
   GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

-  static Status GetMetadataFromDevice(DeviceBase* device,
-                                      const XlaDevice::Metadata** metadata);

   Status MakeTensorFromProto(XlaDeviceContext* device_context,
                              const TensorProto& tensor_proto,

View File

@@ -75,21 +75,21 @@ Status BuildXlaCompilationCache(DeviceBase* device,
   return Status::OK();
 }

-XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
-  DeviceType device_type = ctx->device_type();
+XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
+  auto device = static_cast<Device*>(device_base);
   se::Platform::Id platform_id = nullptr;
   const XlaDevice::Metadata* xla_device_metadata = nullptr;
   se::DeviceMemoryAllocator* custom_allocator = nullptr;

-  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+  if (device->device_type() == DEVICE_CPU) {
     platform_id = se::host::kHostPlatformId;
-  } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
-    platform_id = ctx->device()
-                      ->tensorflow_gpu_device_info()
+  } else if (device->device_type() == DEVICE_GPU) {
+    platform_id = device->tensorflow_gpu_device_info()
                       ->stream->parent()
                       ->platform()
                       ->id();
-  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+  } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
+                 .ok()) {
     // If we are on an XlaDevice, use the underlying XLA platform's allocator
     // directly. We could use the StreamExecutor's allocator which may
     // theoretically be more correct, but XLA returns a nice OOM message in a
@@ -104,8 +104,8 @@ XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
         xla_device_metadata->client()->backend().memory_allocator();
   }

-  return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
-                         custom_allocator);
+  return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
+                         xla_device_metadata, custom_allocator);
 }

 se::DeviceMemoryAllocator* GetAllocator(
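
Stitched together, the rewritten function reads roughly as below. This is a sketch assembled from the two hunks above; the line choosing platform_id inside the XLA-device branch is not visible in the hunks and is an assumption here:

XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
  auto device = static_cast<Device*>(device_base);
  se::Platform::Id platform_id = nullptr;
  const XlaDevice::Metadata* xla_device_metadata = nullptr;
  se::DeviceMemoryAllocator* custom_allocator = nullptr;

  if (device->device_type() == DEVICE_CPU) {
    platform_id = se::host::kHostPlatformId;
  } else if (device->device_type() == DEVICE_GPU) {
    platform_id = device->tensorflow_gpu_device_info()
                      ->stream->parent()
                      ->platform()
                      ->id();
  } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
                 .ok()) {
    // On an XLA device, use the backend's allocator directly (see the
    // comment preserved in the first hunk above).
    platform_id = xla_device_metadata->platform()->id();  // assumed; not shown in the hunks
    custom_allocator =
        xla_device_metadata->client()->backend().memory_allocator();
  }

  return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
                         xla_device_metadata, custom_allocator);
}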

View File

@@ -85,7 +85,7 @@ Status BuildXlaCompilationCache(DeviceBase* dev,
                                 XlaCompilationCache** cache);

 // Returns information about the platform from kernel context.
-XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
+XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);

 // Returns allocator from platform info if non-null, or populate and return a
 // pointer to the allocator adapter with allocator from context.