[TF2XLA] [NFC] Simplify XlaPlatformInfoFromDevice

PiperOrigin-RevId: 328973906
Change-Id: I7e122bbdac7d196eee430c6dc0e1d76505f5126e
George Karpenkov 2020-08-28 11:16:33 -07:00 committed by TensorFlower Gardener
parent b73c9f1236
commit 454ebebeab
5 changed files with 20 additions and 16 deletions


@@ -158,7 +158,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
       constants_(constants),
       resources_(resources),
       function_(function),
-      platform_info_(XlaPlatformInfoFromContext(ctx)),
+      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
       has_ref_vars_(has_ref_vars) {}
 
 static Status CompileToLocalExecutable(
@@ -373,7 +373,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
       constants_(ConstantsVector(ctx)),
       resources_(ResourcesVector(ctx)),
       function_(FunctionAttr(ctx)),
-      platform_info_(XlaPlatformInfoFromContext(ctx)),
+      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
       must_compile_(MustCompileAttr(ctx)),
       has_ref_vars_(HasRefVars(ctx)) {}
@@ -461,7 +461,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
 }
 
 XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
-    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
+    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
 
 void XlaRunOp::Compute(OpKernelContext* ctx) {
   VLOG(3) << "XlaRunOp " << def().name();


@@ -37,7 +37,8 @@ namespace tensorflow {
 class XlaCompileOnDemandOp : public OpKernel {
  public:
   explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
+      : OpKernel(ctx),
+        platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
   void Compute(OpKernelContext* ctx) override;
 
  private:


@@ -94,6 +94,11 @@ class XlaDevice : public LocalDevice {
   static Status GetMetadata(OpKernelConstruction* ctx,
                             const Metadata** metadata);
 
+  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
+  // `device`.
+  static Status GetMetadataFromDevice(DeviceBase* device,
+                                      const XlaDevice::Metadata** metadata);
+
   struct Options {
     // The StreamExecutor platform. Not owned. Must be non-null.
     se::Platform* platform = nullptr;
@@ -196,8 +201,6 @@ class XlaDevice : public LocalDevice {
   xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
   GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  static Status GetMetadataFromDevice(DeviceBase* device,
-                                      const XlaDevice::Metadata** metadata);
 
   Status MakeTensorFromProto(XlaDeviceContext* device_context,
                              const TensorProto& tensor_proto,

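For illustration only, not part of the commit: a minimal sketch of how the newly public XlaDevice::GetMetadataFromDevice accessor can be called given any DeviceBase* (here named `device`); the pattern mirrors the call introduced in the next file's diff.

// Sketch (assumes the TensorFlow headers that declare XlaDevice and Status):
// query the XlaDevice metadata directly from a DeviceBase*, without needing
// an OpKernelConstruction.
const XlaDevice::Metadata* metadata = nullptr;
if (XlaDevice::GetMetadataFromDevice(device, &metadata).ok()) {
  // `device` is backed by an XlaDevice; e.g. metadata->client() is available,
  // as used by XlaPlatformInfoFromDevice below.
}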

@@ -75,21 +75,21 @@ Status BuildXlaCompilationCache(DeviceBase* device,
   return Status::OK();
 }
 
-XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
-  DeviceType device_type = ctx->device_type();
+XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
+  auto device = static_cast<Device*>(device_base);
   se::Platform::Id platform_id = nullptr;
   const XlaDevice::Metadata* xla_device_metadata = nullptr;
   se::DeviceMemoryAllocator* custom_allocator = nullptr;
 
-  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+  if (device->device_type() == DEVICE_CPU) {
     platform_id = se::host::kHostPlatformId;
-  } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
-    platform_id = ctx->device()
-                      ->tensorflow_gpu_device_info()
+  } else if (device->device_type() == DEVICE_GPU) {
+    platform_id = device->tensorflow_gpu_device_info()
                       ->stream->parent()
                       ->platform()
                       ->id();
-  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+  } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
+                 .ok()) {
     // If we are on an XlaDevice, use the underlying XLA platform's allocator
     // directly. We could use the StreamExecutor's allocator which may
     // theoretically be more correct, but XLA returns a nice OOM message in a
@@ -104,8 +104,8 @@ XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
         xla_device_metadata->client()->backend().memory_allocator();
   }
 
-  return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
-                         custom_allocator);
+  return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
+                         xla_device_metadata, custom_allocator);
 }
 
 se::DeviceMemoryAllocator* GetAllocator(


@@ -85,7 +85,7 @@ Status BuildXlaCompilationCache(DeviceBase* dev,
                                 XlaCompilationCache** cache);
 
 // Returns information about the platform from kernel context.
-XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
+XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
 
 // Returns allocator from platform info if non-null, or populate and return a
 // pointer to the allocator adapter with allocator from context.
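For illustration only, not part of the commit: a minimal sketch of the new call pattern from a kernel constructor. `MyLaunchOp` is a hypothetical kernel; the member layout mirrors the kernels updated above.

// Hypothetical kernel showing the new call pattern: platform info is now
// derived from the kernel's device rather than from the OpKernelConstruction.
class MyLaunchOp : public OpKernel {
 public:
  explicit MyLaunchOp(OpKernelConstruction* ctx)
      : OpKernel(ctx),
        platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

 private:
  const XlaPlatformInfo platform_info_;
};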