Go: API functions to serialize/deserialize Tensors.

The current implementation has limitations, in particular
it does not support string tensors. But the same API
will be sufficient when that limitation has been
addressed.

See also #6003. The API added here can be used to
fill in the tensor_contents field of a TensorProto
protocol buffer.
Change: 140760816
This commit is contained in:
Asim Shankar 2016-12-01 12:10:45 -08:00 committed by TensorFlower Gardener
parent c4d507ac75
commit 0ba830748d
2 changed files with 170 additions and 1 deletions

View File

@ -108,6 +108,35 @@ func NewTensor(value interface{}) (*Tensor, error) {
return t, nil
}
// ReadTensor constructs a Tensor with the provided type and shape from the
// serialized tensor contents in r.
//
// See also WriteContentsTo.
func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) {
if err := isTensorSerializable(dataType); err != nil {
return nil, err
}
nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
var shapePtr *C.int64_t
if len(shape) > 0 {
shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
}
t := &Tensor{
c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
shape: shape,
}
runtime.SetFinalizer(t, (*Tensor).finalize)
raw := tensorData(t.c)
n, err := r.Read(raw)
if err != nil {
return nil, err
}
if uintptr(n) != nbytes {
return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
}
return t, nil
}
// newTensorFromC takes ownership of c and returns the owning Tensor.
func newTensorFromC(c *C.TF_Tensor) *Tensor {
var shape []int64
@ -156,6 +185,21 @@ func (t *Tensor) Value() interface{} {
return reflect.Indirect(val).Interface()
}
// WriteContentsTo writes the serialized contents of t to w.
//
// Returns the number of bytes written. See ReadTensor for
// reconstructing a Tensor from the serialized form.
//
// WARNING: WriteContentsTo is not comprehensive and will fail
// if t.DataType() is non-numeric (e.g., String). See
// https://github.com/tensorflow/tensorflow/issues/6003.
func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) {
if err := isTensorSerializable(t.DataType()); err != nil {
return 0, err
}
return io.Copy(w, bytes.NewReader(tensorData(t.c)))
}
func tensorData(c *C.TF_Tensor) []byte {
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
cbytes := C.TF_TensorData(c)
@ -385,6 +429,23 @@ func bug(format string, args ...interface{}) error {
return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
}
func isTensorSerializable(dataType DataType) error {
// For numeric types, the serialized Tensor matches the in-memory
// representation. See the implementation of Tensor::AsProtoContent in
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc
//
// The more appropriate way to be in sync with Tensor::AsProtoContent
// would be to have the TensorFlow C library export functions for
// serialization and deserialization of Tensors. Till then capitalize
// on knowledge of the implementation for numeric types.
switch dataType {
case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half:
return nil
default:
return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
}
}
// nativeEndian is the byte order for the local platform. Used to send back and
// forth Tensors with the C API. We test for endianness at runtime because
// some architectures can be booted into different endian modes.

View File

@ -15,6 +15,7 @@
package tensorflow
import (
"bytes"
"reflect"
"testing"
)
@ -28,7 +29,6 @@ func TestNewTensor(t *testing.T) {
{nil, int16(5)},
{nil, int32(5)},
{nil, int64(5)},
{nil, int64(5)},
{nil, uint8(5)},
{nil, uint16(5)},
{nil, float32(5)},
@ -103,6 +103,114 @@ func TestNewTensor(t *testing.T) {
}
}
func TestTensorSerialization(t *testing.T) {
var tests = []interface{}{
int8(5),
int16(5),
int32(5),
int64(5),
uint8(5),
uint16(5),
float32(5),
float64(5),
complex(float32(5), float32(6)),
complex(float64(5), float64(6)),
[]float64{1},
[][]float32{{1, 2}, {3, 4}, {5, 6}},
[][][]int8{
{{1, 2}, {3, 4}, {5, 6}},
{{7, 8}, {9, 10}, {11, 12}},
{{0, -1}, {-2, -3}, {-4, -5}},
{{-6, -7}, {-8, -9}, {-10, -11}},
},
}
for _, v := range tests {
t1, err := NewTensor(v)
if err != nil {
t.Errorf("(%v): %v", v, err)
continue
}
buf := new(bytes.Buffer)
n, err := t1.WriteContentsTo(buf)
if err != nil {
t.Errorf("(%v): %v", v, err)
continue
}
if n != int64(buf.Len()) {
t.Errorf("(%v): WriteContentsTo said it wrote %v bytes, but wrote %v", v, n, buf.Len())
}
t2, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
if err != nil {
t.Errorf("(%v): %v", v, err)
continue
}
if buf.Len() != 0 {
t.Errorf("(%v): %v bytes written by WriteContentsTo not read by ReadTensor", v, buf.Len())
}
if got, want := t2.DataType(), t1.DataType(); got != want {
t.Errorf("(%v): Got %v, want %v", v, got, want)
}
if got, want := t2.Shape(), t1.Shape(); !reflect.DeepEqual(got, want) {
t.Errorf("(%v): Got %v, want %v", v, got, want)
}
if got, want := t2.Value(), v; !reflect.DeepEqual(got, want) {
t.Errorf("(%v): Got %v, want %v", v, got, want)
}
}
}
func TestReadTensorDoesNotReadBeyondContent(t *testing.T) {
t1, _ := NewTensor(int8(7))
t2, _ := NewTensor(float32(2.718))
buf := new(bytes.Buffer)
if _, err := t1.WriteContentsTo(buf); err != nil {
t.Fatal(err)
}
if _, err := t2.WriteContentsTo(buf); err != nil {
t.Fatal(err)
}
t3, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
if err != nil {
t.Fatal(err)
}
t4, err := ReadTensor(t2.DataType(), t2.Shape(), buf)
if err != nil {
t.Fatal(err)
}
if v, ok := t3.Value().(int8); !ok || v != 7 {
t.Errorf("Got (%v (%T), %v), want (7 (int8), true)", v, v, ok)
}
if v, ok := t4.Value().(float32); !ok || v != 2.718 {
t.Errorf("Got (%v (%T), %v), want (2.718 (float32), true)", v, v, ok)
}
}
func TestTensorSerializationErrors(t *testing.T) {
// String tensors cannot be serialized
t1, err := NewTensor("abcd")
if err != nil {
t.Fatal(err)
}
buf := new(bytes.Buffer)
if n, err := t1.WriteContentsTo(buf); n != 0 || err == nil || buf.Len() != 0 {
t.Errorf("Got (%v, %v, %v) want (0, <non-nil>, 0)", n, err, buf.Len())
}
// Should fail to read a truncated value.
if t1, err = NewTensor(int8(8)); err != nil {
t.Fatal(err)
}
n, err := t1.WriteContentsTo(buf)
if err != nil {
t.Fatal(err)
}
r := bytes.NewReader(buf.Bytes()[:n-1])
if _, err = ReadTensor(t1.DataType(), t1.Shape(), r); err == nil {
t.Error("ReadTensor should have failed if the tensor content was truncated")
}
}
func benchmarkNewTensor(b *testing.B, v interface{}) {
for i := 0; i < b.N; i++ {
if t, err := NewTensor(v); err != nil || t == nil {