Merge pull request #43812 from ppiskorski:get_device_context
PiperOrigin-RevId: 338777432
This commit is contained in:
commit
ad052ba4cc
@ -53,6 +53,8 @@ class CompositeDevice : public Device {
|
|||||||
const std::vector<string>& underlying_devices, const string& device_name,
|
const std::vector<string>& underlying_devices, const string& device_name,
|
||||||
Status* status);
|
Status* status);
|
||||||
|
|
||||||
|
bool IsRemoteCallAllowed() const override { return false; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CompositeDevice(const DeviceAttributes& device_attributes,
|
CompositeDevice(const DeviceAttributes& device_attributes,
|
||||||
const std::vector<string>& underlying_devices)
|
const std::vector<string>& underlying_devices)
|
||||||
|
|||||||
@ -200,13 +200,15 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
|
|||||||
// "TPU_SYSTEM" indicates that `device` is a CPU.
|
// "TPU_SYSTEM" indicates that `device` is a CPU.
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
if (device_type == "GPU" || device_type == "TPU") {
|
|
||||||
|
if (device->IsRemoteCallAllowed()) {
|
||||||
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
|
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
|
||||||
if (dev_info) {
|
if (dev_info) {
|
||||||
*device_context = dev_info->default_context;
|
*device_context = dev_info->default_context;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors::Internal("Device type: ", device_type,
|
return errors::Internal("Device type: ", device_type,
|
||||||
" is currently unsupported for remote ",
|
" is currently unsupported for remote ",
|
||||||
"function executions");
|
"function executions");
|
||||||
|
|||||||
@ -149,6 +149,10 @@ class RenamedDevice : public Device {
|
|||||||
|
|
||||||
bool IsLocal() const override { return underlying_device_->IsLocal(); }
|
bool IsLocal() const override { return underlying_device_->IsLocal(); }
|
||||||
|
|
||||||
|
bool IsRemoteCallAllowed() const override {
|
||||||
|
return underlying_device_->IsRemoteCallAllowed();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
|
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
|
||||||
bool owns_underlying, bool isolate_session_state,
|
bool owns_underlying, bool isolate_session_state,
|
||||||
|
|||||||
@ -50,6 +50,8 @@ class RemoteDevice : public Device {
|
|||||||
|
|
||||||
bool IsLocal() const override { return false; }
|
bool IsLocal() const override { return false; }
|
||||||
|
|
||||||
|
bool IsRemoteCallAllowed() const override { return true; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const string local_dev_name_;
|
const string local_dev_name_;
|
||||||
|
|
||||||
|
|||||||
@ -54,4 +54,21 @@ DeviceAttributes Device::BuildDeviceAttributes(
|
|||||||
return da;
|
return da;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Device::IsRemoteCallAllowed() const {
|
||||||
|
auto& type = parsed_name_.type;
|
||||||
|
if (type == "TPU") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (type == "TPU_SYSTEM") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (type == "CPU") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (type == "GPU") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -178,6 +178,9 @@ class Device : public DeviceBase {
|
|||||||
|
|
||||||
virtual bool IsLocal() const { return true; }
|
virtual bool IsLocal() const { return true; }
|
||||||
|
|
||||||
|
// Informs if this Device can be used as a caller in RemoteCall operation.
|
||||||
|
virtual bool IsRemoteCallAllowed() const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void DeleteResourceMgr() {
|
void DeleteResourceMgr() {
|
||||||
delete rmgr_;
|
delete rmgr_;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user