golang: added Session.ListDevices method (#14385)

* golang: added Session.ListDevices method
This commit is contained in:
Andrei Nigmatulin 2017-12-21 16:04:46 +00:00 committed by Shanqing Cai
parent f950ea836d
commit 9315a12cf7
2 changed files with 61 additions and 0 deletions

View File

@ -65,6 +65,51 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
return s, nil
}
// Device structure contains information about a device associated with a session, as returned by ListDevices()
type Device struct {
Name, Type string
MemoryLimitBytes int64
}
// Return list of devices associated with a Session
func (s *Session) ListDevices() ([]Device, error) {
var devices []Device
status := newStatus()
devices_list := C.TF_SessionListDevices(s.c, status.c)
if err := status.Err(); err != nil {
return nil, fmt.Errorf("SessionListDevices() failed: %v", err)
}
defer C.TF_DeleteDeviceList(devices_list)
for i := 0; i < int(C.TF_DeviceListCount(devices_list)); i++ {
device_name := C.TF_DeviceListName(devices_list, C.int(i), status.c)
if err := status.Err(); err != nil {
return nil, fmt.Errorf("DeviceListName(index=%d) failed: %v", i, err)
}
device_type := C.TF_DeviceListType(devices_list, C.int(i), status.c)
if err := status.Err(); err != nil {
return nil, fmt.Errorf("DeviceListType(index=%d) failed: %v", i, err)
}
memory_limit_bytes := C.TF_DeviceListMemoryBytes(devices_list, C.int(i), status.c)
if err := status.Err(); err != nil {
return nil, fmt.Errorf("DeviceListMemoryBytes(index=%d) failed: %v", i, err)
}
device := Device{
Name: C.GoString(device_name),
Type: C.GoString(device_type),
MemoryLimitBytes: int64(memory_limit_bytes),
}
devices = append(devices, device)
}
return devices, nil
}
// Run the graph with the associated session starting with the supplied feeds
// to compute the value of the requested fetches. Runs, but does not return
// Tensors for operations specified in targets.

View File

@ -283,3 +283,19 @@ func TestSessionConfig(t *testing.T) {
t.Fatalf("Got %v, want -1", output[0].Value())
}
}
func TestListDevices(t *testing.T) {
s, err := NewSession(NewGraph(), nil)
if err != nil {
t.Fatalf("NewSession(): %v", err)
}
devices, err := s.ListDevices()
if err != nil {
t.Fatalf("ListDevices(): %v", err)
}
if len(devices) == 0 {
t.Fatalf("no devices detected")
}
}