Support GetMemoryInfo in TFRT.
PiperOrigin-RevId: 356285474 Change-Id: Ie7e9ab26c6e8a41f91f68fe10c0694afd0a00047
This commit is contained in:
parent
0921a80d56
commit
ed77a63244
@ -167,6 +167,9 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// Update the Eager Executor for current thread.
|
||||
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||
|
||||
// Return a list of local tensorflow::Device*.
|
||||
virtual std::vector<tensorflow::Device*> ListLocalTfDevices() = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Following are helper functions to assist integrating TFRT with current
|
||||
// TF eager runtime.
|
||||
|
@ -307,6 +307,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
return remote_device_manager_.GetOwned();
|
||||
}
|
||||
|
||||
std::vector<Device*> ListLocalTfDevices() override {
|
||||
return local_device_mgr()->ListDevices();
|
||||
}
|
||||
|
||||
// TODO(apassos) clean up RunMetadata storage.
|
||||
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
|
||||
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
|
||||
|
@ -526,9 +526,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
|
||||
m.def(
|
||||
"TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
|
||||
tensorflow::EagerContext* context = tensorflow::ContextFromInterface(
|
||||
auto* context =
|
||||
reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
|
||||
tensorflow::InputTFE_Context(ctx)));
|
||||
tensorflow::InputTFE_Context(ctx));
|
||||
|
||||
tensorflow::DeviceNameUtils::ParsedName input_device_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
|
||||
@ -539,7 +539,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
}
|
||||
|
||||
std::vector<tensorflow::Device*> devices =
|
||||
context->local_device_mgr()->ListDevices();
|
||||
context->ListLocalTfDevices();
|
||||
|
||||
tensorflow::Device* matched_device = nullptr;
|
||||
for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
|
||||
|
Loading…
Reference in New Issue
Block a user