[TF2XLA] [NFC] Simplify XlaPlatformInfoFromDevice
PiperOrigin-RevId: 328973906 Change-Id: I7e122bbdac7d196eee430c6dc0e1d76505f5126e
This commit is contained in:
parent
b73c9f1236
commit
454ebebeab
@ -158,7 +158,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
|||||||
constants_(constants),
|
constants_(constants),
|
||||||
resources_(resources),
|
resources_(resources),
|
||||||
function_(function),
|
function_(function),
|
||||||
platform_info_(XlaPlatformInfoFromContext(ctx)),
|
platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
|
||||||
has_ref_vars_(has_ref_vars) {}
|
has_ref_vars_(has_ref_vars) {}
|
||||||
|
|
||||||
static Status CompileToLocalExecutable(
|
static Status CompileToLocalExecutable(
|
||||||
@ -373,7 +373,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
|
|||||||
constants_(ConstantsVector(ctx)),
|
constants_(ConstantsVector(ctx)),
|
||||||
resources_(ResourcesVector(ctx)),
|
resources_(ResourcesVector(ctx)),
|
||||||
function_(FunctionAttr(ctx)),
|
function_(FunctionAttr(ctx)),
|
||||||
platform_info_(XlaPlatformInfoFromContext(ctx)),
|
platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
|
||||||
must_compile_(MustCompileAttr(ctx)),
|
must_compile_(MustCompileAttr(ctx)),
|
||||||
has_ref_vars_(HasRefVars(ctx)) {}
|
has_ref_vars_(HasRefVars(ctx)) {}
|
||||||
|
|
||||||
@ -461,7 +461,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
|
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
|
||||||
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
|
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
|
||||||
|
|
||||||
void XlaRunOp::Compute(OpKernelContext* ctx) {
|
void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||||
VLOG(3) << "XlaRunOp " << def().name();
|
VLOG(3) << "XlaRunOp " << def().name();
|
||||||
|
@ -37,7 +37,8 @@ namespace tensorflow {
|
|||||||
class XlaCompileOnDemandOp : public OpKernel {
|
class XlaCompileOnDemandOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
|
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
|
||||||
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
|
: OpKernel(ctx),
|
||||||
|
platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
|
||||||
void Compute(OpKernelContext* ctx) override;
|
void Compute(OpKernelContext* ctx) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -94,6 +94,11 @@ class XlaDevice : public LocalDevice {
|
|||||||
static Status GetMetadata(OpKernelConstruction* ctx,
|
static Status GetMetadata(OpKernelConstruction* ctx,
|
||||||
const Metadata** metadata);
|
const Metadata** metadata);
|
||||||
|
|
||||||
|
// Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
|
||||||
|
// `device`.
|
||||||
|
static Status GetMetadataFromDevice(DeviceBase* device,
|
||||||
|
const XlaDevice::Metadata** metadata);
|
||||||
|
|
||||||
struct Options {
|
struct Options {
|
||||||
// The StreamExecutor platform. Not owned. Must be non-null.
|
// The StreamExecutor platform. Not owned. Must be non-null.
|
||||||
se::Platform* platform = nullptr;
|
se::Platform* platform = nullptr;
|
||||||
@ -196,8 +201,6 @@ class XlaDevice : public LocalDevice {
|
|||||||
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
|
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
|
||||||
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
static Status GetMetadataFromDevice(DeviceBase* device,
|
|
||||||
const XlaDevice::Metadata** metadata);
|
|
||||||
|
|
||||||
Status MakeTensorFromProto(XlaDeviceContext* device_context,
|
Status MakeTensorFromProto(XlaDeviceContext* device_context,
|
||||||
const TensorProto& tensor_proto,
|
const TensorProto& tensor_proto,
|
||||||
|
@ -75,21 +75,21 @@ Status BuildXlaCompilationCache(DeviceBase* device,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
|
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
|
||||||
DeviceType device_type = ctx->device_type();
|
auto device = static_cast<Device*>(device_base);
|
||||||
se::Platform::Id platform_id = nullptr;
|
se::Platform::Id platform_id = nullptr;
|
||||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||||
se::DeviceMemoryAllocator* custom_allocator = nullptr;
|
se::DeviceMemoryAllocator* custom_allocator = nullptr;
|
||||||
|
|
||||||
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
|
if (device->device_type() == DEVICE_CPU) {
|
||||||
platform_id = se::host::kHostPlatformId;
|
platform_id = se::host::kHostPlatformId;
|
||||||
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
|
} else if (device->device_type() == DEVICE_GPU) {
|
||||||
platform_id = ctx->device()
|
platform_id = device->tensorflow_gpu_device_info()
|
||||||
->tensorflow_gpu_device_info()
|
|
||||||
->stream->parent()
|
->stream->parent()
|
||||||
->platform()
|
->platform()
|
||||||
->id();
|
->id();
|
||||||
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
|
} else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
|
||||||
|
.ok()) {
|
||||||
// If we are on an XlaDevice, use the underlying XLA platform's allocator
|
// If we are on an XlaDevice, use the underlying XLA platform's allocator
|
||||||
// directly. We could use the StreamExecutor's allocator which may
|
// directly. We could use the StreamExecutor's allocator which may
|
||||||
// theoretically be more correct, but XLA returns a nice OOM message in a
|
// theoretically be more correct, but XLA returns a nice OOM message in a
|
||||||
@ -104,8 +104,8 @@ XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
|
|||||||
xla_device_metadata->client()->backend().memory_allocator();
|
xla_device_metadata->client()->backend().memory_allocator();
|
||||||
}
|
}
|
||||||
|
|
||||||
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
|
return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
|
||||||
custom_allocator);
|
xla_device_metadata, custom_allocator);
|
||||||
}
|
}
|
||||||
|
|
||||||
se::DeviceMemoryAllocator* GetAllocator(
|
se::DeviceMemoryAllocator* GetAllocator(
|
||||||
|
@ -85,7 +85,7 @@ Status BuildXlaCompilationCache(DeviceBase* dev,
|
|||||||
XlaCompilationCache** cache);
|
XlaCompilationCache** cache);
|
||||||
|
|
||||||
// Returns information about the platform from kernel context.
|
// Returns information about the platform from kernel context.
|
||||||
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
|
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
|
||||||
|
|
||||||
// Returns allocator from platform info if non-null, or populate and return a
|
// Returns allocator from platform info if non-null, or populate and return a
|
||||||
// pointer to the allocator adapter with allocator from context.
|
// pointer to the allocator adapter with allocator from context.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user