diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index f755e9d4f8b..c50e5e30aa6 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -108,6 +108,41 @@ func NewTensor(value interface{}) (*Tensor, error) {
 	return t, nil
 }
 
+// ReadTensor constructs a Tensor with the provided type and shape from the
+// serialized tensor contents in r.
+//
+// Only the bytes that make up the tensor are consumed from r. An error is
+// returned if dataType cannot be serialized or if r ends before the full
+// tensor contents have been read.
+//
+// See also WriteContentsTo.
+func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) {
+	if err := isTensorSerializable(dataType); err != nil {
+		return nil, err
+	}
+	nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
+	var shapePtr *C.int64_t
+	if len(shape) > 0 {
+		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
+	}
+	t := &Tensor{
+		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
+		shape: shape,
+	}
+	runtime.SetFinalizer(t, (*Tensor).finalize)
+	raw := tensorData(t.c)
+	// A single Read may return fewer bytes than requested without that being
+	// an error, so use io.ReadFull to keep reading until raw is filled.
+	n, err := io.ReadFull(r, raw)
+	if err == io.ErrUnexpectedEOF {
+		return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
+	}
+	if err != nil {
+		return nil, err
+	}
+	return t, nil
+}
+
 // newTensorFromC takes ownership of c and returns the owning Tensor.
 func newTensorFromC(c *C.TF_Tensor) *Tensor {
 	var shape []int64
@@ -156,6 +191,21 @@ func (t *Tensor) Value() interface{} {
 	return reflect.Indirect(val).Interface()
 }
 
+// WriteContentsTo writes the serialized contents of t to w.
+//
+// Returns the number of bytes written. See ReadTensor for
+// reconstructing a Tensor from the serialized form.
+//
+// WARNING: WriteContentsTo is not comprehensive and will fail
+// if t.DataType() is non-numeric (e.g., String). See
+// https://github.com/tensorflow/tensorflow/issues/6003.
+func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) {
+	if err := isTensorSerializable(t.DataType()); err != nil {
+		return 0, err
+	}
+	return io.Copy(w, bytes.NewReader(tensorData(t.c)))
+}
+
 func tensorData(c *C.TF_Tensor) []byte {
 	// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
 	cbytes := C.TF_TensorData(c)
@@ -385,6 +435,23 @@ func bug(format string, args ...interface{}) error {
 	return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
 }
 
+func isTensorSerializable(dataType DataType) error {
+	// For numeric types, the serialized Tensor matches the in-memory
+	// representation. See the implementation of Tensor::AsProtoContent in
+	// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc
+	//
+	// The more appropriate way to be in sync with Tensor::AsProtoContent
+	// would be to have the TensorFlow C library export functions for
+	// serialization and deserialization of Tensors. Till then capitalize
+	// on knowledge of the implementation for numeric types.
+	switch dataType {
+	case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half:
+		return nil
+	default:
+		return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
+	}
+}
+
 // nativeEndian is the byte order for the local platform. Used to send back and
 // forth Tensors with the C API. We test for endianness at runtime because
 // some architectures can be booted into different endian modes.
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go
index 073da0cc6e4..2a3ed416bdb 100644
--- a/tensorflow/go/tensor_test.go
+++ b/tensorflow/go/tensor_test.go
@@ -15,6 +15,7 @@
 package tensorflow
 
 import (
+	"bytes"
 	"reflect"
 	"testing"
 )
@@ -28,7 +29,6 @@ func TestNewTensor(t *testing.T) {
 		{nil, int16(5)},
 		{nil, int32(5)},
 		{nil, int64(5)},
-		{nil, int64(5)},
 		{nil, uint8(5)},
 		{nil, uint16(5)},
 		{nil, float32(5)},
@@ -103,6 +103,126 @@ func TestNewTensor(t *testing.T) {
 	}
 }
 
+// TestTensorSerialization checks that tensors of a range of numeric types and
+// ranks survive a WriteContentsTo/ReadTensor round trip unchanged.
+func TestTensorSerialization(t *testing.T) {
+	var tests = []interface{}{
+		int8(5),
+		int16(5),
+		int32(5),
+		int64(5),
+		uint8(5),
+		uint16(5),
+		float32(5),
+		float64(5),
+		complex(float32(5), float32(6)),
+		complex(float64(5), float64(6)),
+		[]float64{1},
+		[][]float32{{1, 2}, {3, 4}, {5, 6}},
+		[][][]int8{
+			{{1, 2}, {3, 4}, {5, 6}},
+			{{7, 8}, {9, 10}, {11, 12}},
+			{{0, -1}, {-2, -3}, {-4, -5}},
+			{{-6, -7}, {-8, -9}, {-10, -11}},
+		},
+	}
+	for _, v := range tests {
+		t1, err := NewTensor(v)
+		if err != nil {
+			t.Errorf("(%v): %v", v, err)
+			continue
+		}
+		buf := new(bytes.Buffer)
+		n, err := t1.WriteContentsTo(buf)
+		if err != nil {
+			t.Errorf("(%v): %v", v, err)
+			continue
+		}
+		if n != int64(buf.Len()) {
+			t.Errorf("(%v): WriteContentsTo said it wrote %v bytes, but wrote %v", v, n, buf.Len())
+		}
+		t2, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
+		if err != nil {
+			t.Errorf("(%v): %v", v, err)
+			continue
+		}
+		if buf.Len() != 0 {
+			t.Errorf("(%v): %v bytes written by WriteContentsTo not read by ReadTensor", v, buf.Len())
+		}
+		if got, want := t2.DataType(), t1.DataType(); got != want {
+			t.Errorf("(%v): Got %v, want %v", v, got, want)
+		}
+		if got, want := t2.Shape(), t1.Shape(); !reflect.DeepEqual(got, want) {
+			t.Errorf("(%v): Got %v, want %v", v, got, want)
+		}
+		if got, want := t2.Value(), v; !reflect.DeepEqual(got, want) {
+			t.Errorf("(%v): Got %v, want %v", v, got, want)
+		}
+	}
+}
+
+// TestReadTensorDoesNotReadBeyondContent verifies that ReadTensor consumes
+// only its own tensor's bytes when several tensors share one stream.
+func TestReadTensorDoesNotReadBeyondContent(t *testing.T) {
+	t1, err := NewTensor(int8(7))
+	if err != nil {
+		t.Fatal(err)
+	}
+	t2, err := NewTensor(float32(2.718))
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf := new(bytes.Buffer)
+	if _, err := t1.WriteContentsTo(buf); err != nil {
+		t.Fatal(err)
+	}
+	if _, err := t2.WriteContentsTo(buf); err != nil {
+		t.Fatal(err)
+	}
+
+	t3, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	t4, err := ReadTensor(t2.DataType(), t2.Shape(), buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if v, ok := t3.Value().(int8); !ok || v != 7 {
+		t.Errorf("Got (%v (%T), %v), want (7 (int8), true)", v, v, ok)
+	}
+	if v, ok := t4.Value().(float32); !ok || v != 2.718 {
+		t.Errorf("Got (%v (%T), %v), want (2.718 (float32), true)", v, v, ok)
+	}
+}
+
+// TestTensorSerializationErrors covers the failure paths: unsupported
+// DataTypes on write and truncated input on read.
+func TestTensorSerializationErrors(t *testing.T) {
+	// String tensors cannot be serialized
+	t1, err := NewTensor("abcd")
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf := new(bytes.Buffer)
+	if n, err := t1.WriteContentsTo(buf); n != 0 || err == nil || buf.Len() != 0 {
+		t.Errorf("Got (%v, %v, %v) want (0, <non-nil>, 0)", n, err, buf.Len())
+	}
+	// Should fail to read a truncated value.
+	if t1, err = NewTensor(int8(8)); err != nil {
+		t.Fatal(err)
+	}
+	n, err := t1.WriteContentsTo(buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	r := bytes.NewReader(buf.Bytes()[:n-1])
+	if _, err = ReadTensor(t1.DataType(), t1.Shape(), r); err == nil {
+		t.Error("ReadTensor should have failed if the tensor content was truncated")
+	}
+}
+
 func benchmarkNewTensor(b *testing.B, v interface{}) {
 	for i := 0; i < b.N; i++ {
 		if t, err := NewTensor(v); err != nil || t == nil {