Support GetMemoryInfo in TFRT.
PiperOrigin-RevId: 356285474 Change-Id: Ie7e9ab26c6e8a41f91f68fe10c0694afd0a00047
This commit is contained in:
parent
0921a80d56
commit
ed77a63244
@ -167,6 +167,9 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||||||
// Update the Eager Executor for current thread.
|
// Update the Eager Executor for current thread.
|
||||||
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||||
|
|
||||||
|
// Return a list of local tensorflow::Device*.
|
||||||
|
virtual std::vector<tensorflow::Device*> ListLocalTfDevices() = 0;
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Following are helper functions to assist integrating TFRT with current
|
// Following are helper functions to assist integrating TFRT with current
|
||||||
// TF eager runtime.
|
// TF eager runtime.
|
||||||
|
@ -307,6 +307,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
return remote_device_manager_.GetOwned();
|
return remote_device_manager_.GetOwned();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Device*> ListLocalTfDevices() override {
|
||||||
|
return local_device_mgr()->ListDevices();
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(apassos) clean up RunMetadata storage.
|
// TODO(apassos) clean up RunMetadata storage.
|
||||||
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
|
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
|
||||||
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
|
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
|
||||||
|
@ -526,9 +526,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
|
"TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
|
||||||
tensorflow::EagerContext* context = tensorflow::ContextFromInterface(
|
auto* context =
|
||||||
reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
|
reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
|
||||||
tensorflow::InputTFE_Context(ctx)));
|
tensorflow::InputTFE_Context(ctx));
|
||||||
|
|
||||||
tensorflow::DeviceNameUtils::ParsedName input_device_name;
|
tensorflow::DeviceNameUtils::ParsedName input_device_name;
|
||||||
if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
|
if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
|
||||||
@ -539,7 +539,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<tensorflow::Device*> devices =
|
std::vector<tensorflow::Device*> devices =
|
||||||
context->local_device_mgr()->ListDevices();
|
context->ListLocalTfDevices();
|
||||||
|
|
||||||
tensorflow::Device* matched_device = nullptr;
|
tensorflow::Device* matched_device = nullptr;
|
||||||
for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
|
for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user