From ca096a329c28ad45c03a23339e34669fce0914ff Mon Sep 17 00:00:00 2001 From: James Keeling Date: Thu, 13 Dec 2018 03:16:24 -0800 Subject: [PATCH] TF Go Wrapper: Add String() method to Device This is useful when printing devices. PiperOrigin-RevId: 225343692 --- tensorflow/go/session.go | 9 +++++++++ tensorflow/go/session_test.go | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index db6ae4f26cd..bd4fd4f32f7 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -71,6 +71,15 @@ type Device struct { MemoryLimitBytes int64 } +// String describes d and implements fmt.Stringer. +func (d Device) String() string { + memStr := "no memory limit" + if d.MemoryLimitBytes >= 0 { + memStr = fmt.Sprintf("memory limit %d bytes", d.MemoryLimitBytes) + } + return fmt.Sprintf("(Device: name \"%s\", type %s, %s)", d.Name, d.Type, memStr) +} + // Return list of devices associated with a Session func (s *Session) ListDevices() ([]Device, error) { var devices []Device diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index 05ace99a238..c9bda001671 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -299,3 +299,21 @@ func TestListDevices(t *testing.T) { t.Fatalf("no devices detected") } } + +func TestDeviceString(t *testing.T) { + d := Device{Name: "foo", Type: "bar", MemoryLimitBytes: 12345} + got := d.String() + want := "(Device: name \"foo\", type bar, memory limit 12345 bytes)" + if got != want { + t.Errorf("Got \"%s\", want \"%s\"", got, want) + } +} + +func TestDeviceStringNoMemoryLimit(t *testing.T) { + d := Device{Name: "foo", Type: "bar", MemoryLimitBytes: -1} + got := d.String() + want := "(Device: name \"foo\", type bar, no memory limit)" + if got != want { + t.Errorf("Got \"%s\", want \"%s\"", got, want) + } +}