TF Go Wrapper: Add String() method to Device

This is useful when printing devices.

PiperOrigin-RevId: 225343692
This commit is contained in:
James Keeling 2018-12-13 03:16:24 -08:00 committed by TensorFlower Gardener
parent 5dfb096b53
commit ca096a329c
2 changed files with 27 additions and 0 deletions

View File

@ -71,6 +71,15 @@ type Device struct {
MemoryLimitBytes int64
}
// String describes d and implements fmt.Stringer.
func (d Device) String() string {
memStr := "no memory limit"
if d.MemoryLimitBytes >= 0 {
memStr = fmt.Sprintf("memory limit %d bytes", d.MemoryLimitBytes)
}
return fmt.Sprintf("(Device: name \"%s\", type %s, %s)", d.Name, d.Type, memStr)
}
// Return list of devices associated with a Session
func (s *Session) ListDevices() ([]Device, error) {
var devices []Device

View File

@ -299,3 +299,21 @@ func TestListDevices(t *testing.T) {
t.Fatalf("no devices detected")
}
}
func TestDeviceString(t *testing.T) {
d := Device{Name: "foo", Type: "bar", MemoryLimitBytes: 12345}
got := d.String()
want := "(Device: name \"foo\", type bar, memory limit 12345 bytes)"
if got != want {
t.Errorf("Got \"%s\", want \"%s\"", got, want)
}
}
func TestDeviceStringNoMemoryLimit(t *testing.T) {
d := Device{Name: "foo", Type: "bar", MemoryLimitBytes: -1}
got := d.String()
want := "(Device: name \"foo\", type bar, no memory limit)"
if got != want {
t.Errorf("Got \"%s\", want \"%s\"", got, want)
}
}