golang: added Session.ListDevices method (#14385)
* golang: added Session.ListDevices method
This commit is contained in:
parent
f950ea836d
commit
9315a12cf7
@ -65,6 +65,51 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Device structure contains information about a device associated with a session, as returned by ListDevices()
|
||||
type Device struct {
|
||||
Name, Type string
|
||||
MemoryLimitBytes int64
|
||||
}
|
||||
|
||||
// Return list of devices associated with a Session
|
||||
func (s *Session) ListDevices() ([]Device, error) {
|
||||
var devices []Device
|
||||
|
||||
status := newStatus()
|
||||
devices_list := C.TF_SessionListDevices(s.c, status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, fmt.Errorf("SessionListDevices() failed: %v", err)
|
||||
}
|
||||
defer C.TF_DeleteDeviceList(devices_list)
|
||||
|
||||
for i := 0; i < int(C.TF_DeviceListCount(devices_list)); i++ {
|
||||
device_name := C.TF_DeviceListName(devices_list, C.int(i), status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, fmt.Errorf("DeviceListName(index=%d) failed: %v", i, err)
|
||||
}
|
||||
|
||||
device_type := C.TF_DeviceListType(devices_list, C.int(i), status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, fmt.Errorf("DeviceListType(index=%d) failed: %v", i, err)
|
||||
}
|
||||
|
||||
memory_limit_bytes := C.TF_DeviceListMemoryBytes(devices_list, C.int(i), status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, fmt.Errorf("DeviceListMemoryBytes(index=%d) failed: %v", i, err)
|
||||
}
|
||||
|
||||
device := Device{
|
||||
Name: C.GoString(device_name),
|
||||
Type: C.GoString(device_type),
|
||||
MemoryLimitBytes: int64(memory_limit_bytes),
|
||||
}
|
||||
|
||||
devices = append(devices, device)
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
// Run the graph with the associated session starting with the supplied feeds
|
||||
// to compute the value of the requested fetches. Runs, but does not return
|
||||
// Tensors for operations specified in targets.
|
||||
|
@ -283,3 +283,19 @@ func TestSessionConfig(t *testing.T) {
|
||||
t.Fatalf("Got %v, want -1", output[0].Value())
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDevices(t *testing.T) {
|
||||
s, err := NewSession(NewGraph(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewSession(): %v", err)
|
||||
}
|
||||
|
||||
devices, err := s.ListDevices()
|
||||
if err != nil {
|
||||
t.Fatalf("ListDevices(): %v", err)
|
||||
}
|
||||
|
||||
if len(devices) == 0 {
|
||||
t.Fatalf("no devices detected")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user