[TF2XLA] [NFC] Simplify XlaPlatformInfoFromDevice
PiperOrigin-RevId: 328973906 Change-Id: I7e122bbdac7d196eee430c6dc0e1d76505f5126e
This commit is contained in:
parent
b73c9f1236
commit
454ebebeab
@ -158,7 +158,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
constants_(constants),
|
||||
resources_(resources),
|
||||
function_(function),
|
||||
platform_info_(XlaPlatformInfoFromContext(ctx)),
|
||||
platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
|
||||
has_ref_vars_(has_ref_vars) {}
|
||||
|
||||
static Status CompileToLocalExecutable(
|
||||
@ -373,7 +373,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
|
||||
constants_(ConstantsVector(ctx)),
|
||||
resources_(ResourcesVector(ctx)),
|
||||
function_(FunctionAttr(ctx)),
|
||||
platform_info_(XlaPlatformInfoFromContext(ctx)),
|
||||
platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
|
||||
must_compile_(MustCompileAttr(ctx)),
|
||||
has_ref_vars_(HasRefVars(ctx)) {}
|
||||
|
||||
@ -461,7 +461,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
|
||||
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
|
||||
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
|
||||
|
||||
void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(3) << "XlaRunOp " << def().name();
|
||||
|
@ -37,7 +37,8 @@ namespace tensorflow {
|
||||
class XlaCompileOnDemandOp : public OpKernel {
|
||||
public:
|
||||
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
|
||||
: OpKernel(ctx),
|
||||
platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
|
@ -94,6 +94,11 @@ class XlaDevice : public LocalDevice {
|
||||
static Status GetMetadata(OpKernelConstruction* ctx,
|
||||
const Metadata** metadata);
|
||||
|
||||
// Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
|
||||
// `device`.
|
||||
static Status GetMetadataFromDevice(DeviceBase* device,
|
||||
const XlaDevice::Metadata** metadata);
|
||||
|
||||
struct Options {
|
||||
// The StreamExecutor platform. Not owned. Must be non-null.
|
||||
se::Platform* platform = nullptr;
|
||||
@ -196,8 +201,6 @@ class XlaDevice : public LocalDevice {
|
||||
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
|
||||
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
static Status GetMetadataFromDevice(DeviceBase* device,
|
||||
const XlaDevice::Metadata** metadata);
|
||||
|
||||
Status MakeTensorFromProto(XlaDeviceContext* device_context,
|
||||
const TensorProto& tensor_proto,
|
||||
|
@ -75,21 +75,21 @@ Status BuildXlaCompilationCache(DeviceBase* device,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
DeviceType device_type = ctx->device_type();
|
||||
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
|
||||
auto device = static_cast<Device*>(device_base);
|
||||
se::Platform::Id platform_id = nullptr;
|
||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||
se::DeviceMemoryAllocator* custom_allocator = nullptr;
|
||||
|
||||
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
|
||||
if (device->device_type() == DEVICE_CPU) {
|
||||
platform_id = se::host::kHostPlatformId;
|
||||
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
|
||||
platform_id = ctx->device()
|
||||
->tensorflow_gpu_device_info()
|
||||
} else if (device->device_type() == DEVICE_GPU) {
|
||||
platform_id = device->tensorflow_gpu_device_info()
|
||||
->stream->parent()
|
||||
->platform()
|
||||
->id();
|
||||
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
|
||||
} else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
|
||||
.ok()) {
|
||||
// If we are on an XlaDevice, use the underlying XLA platform's allocator
|
||||
// directly. We could use the StreamExecutor's allocator which may
|
||||
// theoretically be more correct, but XLA returns a nice OOM message in a
|
||||
@ -104,8 +104,8 @@ XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
xla_device_metadata->client()->backend().memory_allocator();
|
||||
}
|
||||
|
||||
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
|
||||
custom_allocator);
|
||||
return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
|
||||
xla_device_metadata, custom_allocator);
|
||||
}
|
||||
|
||||
se::DeviceMemoryAllocator* GetAllocator(
|
||||
|
@ -85,7 +85,7 @@ Status BuildXlaCompilationCache(DeviceBase* dev,
|
||||
XlaCompilationCache** cache);
|
||||
|
||||
// Returns information about the platform from kernel context.
|
||||
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
|
||||
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
|
||||
|
||||
// Returns allocator from platform info if non-null, or populate and return a
|
||||
// pointer to the allocator adapter with allocator from context.
|
||||
|
Loading…
Reference in New Issue
Block a user