Go: API functions to serialize/deserialize Tensors.
The current implementation has limitations, in particular it does not support string tensors. But the same API will be sufficient when that limitation has been addressed. See also #6003. The API added here can be used to fill in the tensor_contents field of a TensorProto protocol buffer. Change: 140760816
This commit is contained in:
parent
c4d507ac75
commit
0ba830748d
@ -108,6 +108,35 @@ func NewTensor(value interface{}) (*Tensor, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// ReadTensor constructs a Tensor with the provided type and shape from the
|
||||
// serialized tensor contents in r.
|
||||
//
|
||||
// See also WriteContentsTo.
|
||||
func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) {
|
||||
if err := isTensorSerializable(dataType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
|
||||
var shapePtr *C.int64_t
|
||||
if len(shape) > 0 {
|
||||
shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
|
||||
}
|
||||
t := &Tensor{
|
||||
c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
|
||||
shape: shape,
|
||||
}
|
||||
runtime.SetFinalizer(t, (*Tensor).finalize)
|
||||
raw := tensorData(t.c)
|
||||
n, err := r.Read(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if uintptr(n) != nbytes {
|
||||
return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// newTensorFromC takes ownership of c and returns the owning Tensor.
|
||||
func newTensorFromC(c *C.TF_Tensor) *Tensor {
|
||||
var shape []int64
|
||||
@ -156,6 +185,21 @@ func (t *Tensor) Value() interface{} {
|
||||
return reflect.Indirect(val).Interface()
|
||||
}
|
||||
|
||||
// WriteContentsTo writes the serialized contents of t to w.
|
||||
//
|
||||
// Returns the number of bytes written. See ReadTensor for
|
||||
// reconstructing a Tensor from the serialized form.
|
||||
//
|
||||
// WARNING: WriteContentsTo is not comprehensive and will fail
|
||||
// if t.DataType() is non-numeric (e.g., String). See
|
||||
// https://github.com/tensorflow/tensorflow/issues/6003.
|
||||
func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) {
|
||||
if err := isTensorSerializable(t.DataType()); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return io.Copy(w, bytes.NewReader(tensorData(t.c)))
|
||||
}
|
||||
|
||||
func tensorData(c *C.TF_Tensor) []byte {
|
||||
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||
cbytes := C.TF_TensorData(c)
|
||||
@ -385,6 +429,23 @@ func bug(format string, args ...interface{}) error {
|
||||
return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func isTensorSerializable(dataType DataType) error {
|
||||
// For numeric types, the serialized Tensor matches the in-memory
|
||||
// representation. See the implementation of Tensor::AsProtoContent in
|
||||
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc
|
||||
//
|
||||
// The more appropriate way to be in sync with Tensor::AsProtoContent
|
||||
// would be to have the TensorFlow C library export functions for
|
||||
// serialization and deserialization of Tensors. Till then capitalize
|
||||
// on knowledge of the implementation for numeric types.
|
||||
switch dataType {
|
||||
case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
|
||||
}
|
||||
}
|
||||
|
||||
// nativeEndian is the byte order for the local platform. Used to send back and
|
||||
// forth Tensors with the C API. We test for endianness at runtime because
|
||||
// some architectures can be booted into different endian modes.
|
||||
|
@ -15,6 +15,7 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
@ -28,7 +29,6 @@ func TestNewTensor(t *testing.T) {
|
||||
{nil, int16(5)},
|
||||
{nil, int32(5)},
|
||||
{nil, int64(5)},
|
||||
{nil, int64(5)},
|
||||
{nil, uint8(5)},
|
||||
{nil, uint16(5)},
|
||||
{nil, float32(5)},
|
||||
@ -103,6 +103,114 @@ func TestNewTensor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorSerialization(t *testing.T) {
|
||||
var tests = []interface{}{
|
||||
int8(5),
|
||||
int16(5),
|
||||
int32(5),
|
||||
int64(5),
|
||||
uint8(5),
|
||||
uint16(5),
|
||||
float32(5),
|
||||
float64(5),
|
||||
complex(float32(5), float32(6)),
|
||||
complex(float64(5), float64(6)),
|
||||
[]float64{1},
|
||||
[][]float32{{1, 2}, {3, 4}, {5, 6}},
|
||||
[][][]int8{
|
||||
{{1, 2}, {3, 4}, {5, 6}},
|
||||
{{7, 8}, {9, 10}, {11, 12}},
|
||||
{{0, -1}, {-2, -3}, {-4, -5}},
|
||||
{{-6, -7}, {-8, -9}, {-10, -11}},
|
||||
},
|
||||
}
|
||||
for _, v := range tests {
|
||||
t1, err := NewTensor(v)
|
||||
if err != nil {
|
||||
t.Errorf("(%v): %v", v, err)
|
||||
continue
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
n, err := t1.WriteContentsTo(buf)
|
||||
if err != nil {
|
||||
t.Errorf("(%v): %v", v, err)
|
||||
continue
|
||||
}
|
||||
if n != int64(buf.Len()) {
|
||||
t.Errorf("(%v): WriteContentsTo said it wrote %v bytes, but wrote %v", v, n, buf.Len())
|
||||
}
|
||||
t2, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
|
||||
if err != nil {
|
||||
t.Errorf("(%v): %v", v, err)
|
||||
continue
|
||||
}
|
||||
if buf.Len() != 0 {
|
||||
t.Errorf("(%v): %v bytes written by WriteContentsTo not read by ReadTensor", v, buf.Len())
|
||||
}
|
||||
if got, want := t2.DataType(), t1.DataType(); got != want {
|
||||
t.Errorf("(%v): Got %v, want %v", v, got, want)
|
||||
}
|
||||
if got, want := t2.Shape(), t1.Shape(); !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("(%v): Got %v, want %v", v, got, want)
|
||||
}
|
||||
if got, want := t2.Value(), v; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("(%v): Got %v, want %v", v, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTensorDoesNotReadBeyondContent(t *testing.T) {
|
||||
t1, _ := NewTensor(int8(7))
|
||||
t2, _ := NewTensor(float32(2.718))
|
||||
buf := new(bytes.Buffer)
|
||||
if _, err := t1.WriteContentsTo(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := t2.WriteContentsTo(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t3, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t4, err := ReadTensor(t2.DataType(), t2.Shape(), buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v, ok := t3.Value().(int8); !ok || v != 7 {
|
||||
t.Errorf("Got (%v (%T), %v), want (7 (int8), true)", v, v, ok)
|
||||
}
|
||||
if v, ok := t4.Value().(float32); !ok || v != 2.718 {
|
||||
t.Errorf("Got (%v (%T), %v), want (2.718 (float32), true)", v, v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorSerializationErrors(t *testing.T) {
|
||||
// String tensors cannot be serialized
|
||||
t1, err := NewTensor("abcd")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
if n, err := t1.WriteContentsTo(buf); n != 0 || err == nil || buf.Len() != 0 {
|
||||
t.Errorf("Got (%v, %v, %v) want (0, <non-nil>, 0)", n, err, buf.Len())
|
||||
}
|
||||
// Should fail to read a truncated value.
|
||||
if t1, err = NewTensor(int8(8)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
n, err := t1.WriteContentsTo(buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := bytes.NewReader(buf.Bytes()[:n-1])
|
||||
if _, err = ReadTensor(t1.DataType(), t1.Shape(), r); err == nil {
|
||||
t.Error("ReadTensor should have failed if the tensor content was truncated")
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkNewTensor(b *testing.B, v interface{}) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if t, err := NewTensor(v); err != nil || t == nil {
|
||||
|
Loading…
Reference in New Issue
Block a user