diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index fc914f86df3..db6ae4f26cd 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -65,6 +65,51 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) { return s, nil } +// Device structure contains information about a device associated with a session, as returned by ListDevices() +type Device struct { + Name, Type string + MemoryLimitBytes int64 +} + +// Return list of devices associated with a Session +func (s *Session) ListDevices() ([]Device, error) { + var devices []Device + + status := newStatus() + devices_list := C.TF_SessionListDevices(s.c, status.c) + if err := status.Err(); err != nil { + return nil, fmt.Errorf("SessionListDevices() failed: %v", err) + } + defer C.TF_DeleteDeviceList(devices_list) + + for i := 0; i < int(C.TF_DeviceListCount(devices_list)); i++ { + device_name := C.TF_DeviceListName(devices_list, C.int(i), status.c) + if err := status.Err(); err != nil { + return nil, fmt.Errorf("DeviceListName(index=%d) failed: %v", i, err) + } + + device_type := C.TF_DeviceListType(devices_list, C.int(i), status.c) + if err := status.Err(); err != nil { + return nil, fmt.Errorf("DeviceListType(index=%d) failed: %v", i, err) + } + + memory_limit_bytes := C.TF_DeviceListMemoryBytes(devices_list, C.int(i), status.c) + if err := status.Err(); err != nil { + return nil, fmt.Errorf("DeviceListMemoryBytes(index=%d) failed: %v", i, err) + } + + device := Device{ + Name: C.GoString(device_name), + Type: C.GoString(device_type), + MemoryLimitBytes: int64(memory_limit_bytes), + } + + devices = append(devices, device) + } + + return devices, nil +} + // Run the graph with the associated session starting with the supplied feeds // to compute the value of the requested fetches. Runs, but does not return // Tensors for operations specified in targets. diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index 73d78a8e577..05ace99a238 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -283,3 +283,19 @@ func TestSessionConfig(t *testing.T) { t.Fatalf("Got %v, want -1", output[0].Value()) } } + +func TestListDevices(t *testing.T) { + s, err := NewSession(NewGraph(), nil) + if err != nil { + t.Fatalf("NewSession(): %v", err) + } + + devices, err := s.ListDevices() + if err != nil { + t.Fatalf("ListDevices(): %v", err) + } + + if len(devices) == 0 { + t.Fatalf("no devices detected") + } +}