Support GetMemoryInfo in TFRT.

PiperOrigin-RevId: 356285474
Change-Id: Ie7e9ab26c6e8a41f91f68fe10c0694afd0a00047
This commit is contained in:
Chuanhao Zhuge 2021-02-08 09:39:20 -08:00 committed by TensorFlower Gardener
parent 0921a80d56
commit ed77a63244
3 changed files with 10 additions and 3 deletions

View File

@ -167,6 +167,9 @@ class ImmediateExecutionContext : public AbstractContext {
// Update the Eager Executor for current thread.
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
// Return a list of local tensorflow::Device*.
virtual std::vector<tensorflow::Device*> ListLocalTfDevices() = 0;
//===--------------------------------------------------------------------===//
// Following are helper functions to assist integrating TFRT with current
// TF eager runtime.

View File

@ -307,6 +307,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
return remote_device_manager_.GetOwned();
}
std::vector<Device*> ListLocalTfDevices() override {
return local_device_mgr()->ListDevices();
}
// TODO(apassos) clean up RunMetadata storage.
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);

View File

@ -526,9 +526,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def(
"TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
tensorflow::EagerContext* context = tensorflow::ContextFromInterface(
auto* context =
reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
tensorflow::InputTFE_Context(ctx)));
tensorflow::InputTFE_Context(ctx));
tensorflow::DeviceNameUtils::ParsedName input_device_name;
if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
@ -539,7 +539,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
}
std::vector<tensorflow::Device*> devices =
context->local_device_mgr()->ListDevices();
context->ListLocalTfDevices();
tensorflow::Device* matched_device = nullptr;
for (int device_idx = 0; device_idx < devices.size(); device_idx++) {