Go: API functions to serialize/deserialize Tensors.
The current implementation has limitations; in particular, it does not support string tensors. However, the same API will be sufficient once that limitation has been addressed. See also #6003. The API added here can be used to fill in the tensor_content field of a TensorProto protocol buffer. Change: 140760816
This commit is contained in:
parent
c4d507ac75
commit
0ba830748d
@ -108,6 +108,35 @@ func NewTensor(value interface{}) (*Tensor, error) {
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadTensor constructs a Tensor with the provided type and shape from the
|
||||||
|
// serialized tensor contents in r.
|
||||||
|
//
|
||||||
|
// See also WriteContentsTo.
|
||||||
|
func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) {
|
||||||
|
if err := isTensorSerializable(dataType); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
|
||||||
|
var shapePtr *C.int64_t
|
||||||
|
if len(shape) > 0 {
|
||||||
|
shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
|
||||||
|
}
|
||||||
|
t := &Tensor{
|
||||||
|
c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
|
||||||
|
shape: shape,
|
||||||
|
}
|
||||||
|
runtime.SetFinalizer(t, (*Tensor).finalize)
|
||||||
|
raw := tensorData(t.c)
|
||||||
|
n, err := r.Read(raw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if uintptr(n) != nbytes {
|
||||||
|
return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
// newTensorFromC takes ownership of c and returns the owning Tensor.
|
// newTensorFromC takes ownership of c and returns the owning Tensor.
|
||||||
func newTensorFromC(c *C.TF_Tensor) *Tensor {
|
func newTensorFromC(c *C.TF_Tensor) *Tensor {
|
||||||
var shape []int64
|
var shape []int64
|
||||||
@ -156,6 +185,21 @@ func (t *Tensor) Value() interface{} {
|
|||||||
return reflect.Indirect(val).Interface()
|
return reflect.Indirect(val).Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteContentsTo writes the serialized contents of t to w.
|
||||||
|
//
|
||||||
|
// Returns the number of bytes written. See ReadTensor for
|
||||||
|
// reconstructing a Tensor from the serialized form.
|
||||||
|
//
|
||||||
|
// WARNING: WriteContentsTo is not comprehensive and will fail
|
||||||
|
// if t.DataType() is non-numeric (e.g., String). See
|
||||||
|
// https://github.com/tensorflow/tensorflow/issues/6003.
|
||||||
|
func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) {
|
||||||
|
if err := isTensorSerializable(t.DataType()); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return io.Copy(w, bytes.NewReader(tensorData(t.c)))
|
||||||
|
}
|
||||||
|
|
||||||
func tensorData(c *C.TF_Tensor) []byte {
|
func tensorData(c *C.TF_Tensor) []byte {
|
||||||
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||||
cbytes := C.TF_TensorData(c)
|
cbytes := C.TF_TensorData(c)
|
||||||
@ -385,6 +429,23 @@ func bug(format string, args ...interface{}) error {
|
|||||||
return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
|
return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isTensorSerializable(dataType DataType) error {
|
||||||
|
// For numeric types, the serialized Tensor matches the in-memory
|
||||||
|
// representation. See the implementation of Tensor::AsProtoContent in
|
||||||
|
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc
|
||||||
|
//
|
||||||
|
// The more appropriate way to be in sync with Tensor::AsProtoContent
|
||||||
|
// would be to have the TensorFlow C library export functions for
|
||||||
|
// serialization and deserialization of Tensors. Till then capitalize
|
||||||
|
// on knowledge of the implementation for numeric types.
|
||||||
|
switch dataType {
|
||||||
|
case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// nativeEndian is the byte order for the local platform. Used to send back and
|
// nativeEndian is the byte order for the local platform. Used to send back and
|
||||||
// forth Tensors with the C API. We test for endianness at runtime because
|
// forth Tensors with the C API. We test for endianness at runtime because
|
||||||
// some architectures can be booted into different endian modes.
|
// some architectures can be booted into different endian modes.
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package tensorflow
|
package tensorflow
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -28,7 +29,6 @@ func TestNewTensor(t *testing.T) {
|
|||||||
{nil, int16(5)},
|
{nil, int16(5)},
|
||||||
{nil, int32(5)},
|
{nil, int32(5)},
|
||||||
{nil, int64(5)},
|
{nil, int64(5)},
|
||||||
{nil, int64(5)},
|
|
||||||
{nil, uint8(5)},
|
{nil, uint8(5)},
|
||||||
{nil, uint16(5)},
|
{nil, uint16(5)},
|
||||||
{nil, float32(5)},
|
{nil, float32(5)},
|
||||||
@ -103,6 +103,114 @@ func TestNewTensor(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTensorSerialization(t *testing.T) {
|
||||||
|
var tests = []interface{}{
|
||||||
|
int8(5),
|
||||||
|
int16(5),
|
||||||
|
int32(5),
|
||||||
|
int64(5),
|
||||||
|
uint8(5),
|
||||||
|
uint16(5),
|
||||||
|
float32(5),
|
||||||
|
float64(5),
|
||||||
|
complex(float32(5), float32(6)),
|
||||||
|
complex(float64(5), float64(6)),
|
||||||
|
[]float64{1},
|
||||||
|
[][]float32{{1, 2}, {3, 4}, {5, 6}},
|
||||||
|
[][][]int8{
|
||||||
|
{{1, 2}, {3, 4}, {5, 6}},
|
||||||
|
{{7, 8}, {9, 10}, {11, 12}},
|
||||||
|
{{0, -1}, {-2, -3}, {-4, -5}},
|
||||||
|
{{-6, -7}, {-8, -9}, {-10, -11}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, v := range tests {
|
||||||
|
t1, err := NewTensor(v)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("(%v): %v", v, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
n, err := t1.WriteContentsTo(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("(%v): %v", v, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n != int64(buf.Len()) {
|
||||||
|
t.Errorf("(%v): WriteContentsTo said it wrote %v bytes, but wrote %v", v, n, buf.Len())
|
||||||
|
}
|
||||||
|
t2, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("(%v): %v", v, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if buf.Len() != 0 {
|
||||||
|
t.Errorf("(%v): %v bytes written by WriteContentsTo not read by ReadTensor", v, buf.Len())
|
||||||
|
}
|
||||||
|
if got, want := t2.DataType(), t1.DataType(); got != want {
|
||||||
|
t.Errorf("(%v): Got %v, want %v", v, got, want)
|
||||||
|
}
|
||||||
|
if got, want := t2.Shape(), t1.Shape(); !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("(%v): Got %v, want %v", v, got, want)
|
||||||
|
}
|
||||||
|
if got, want := t2.Value(), v; !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("(%v): Got %v, want %v", v, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadTensorDoesNotReadBeyondContent(t *testing.T) {
|
||||||
|
t1, _ := NewTensor(int8(7))
|
||||||
|
t2, _ := NewTensor(float32(2.718))
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
if _, err := t1.WriteContentsTo(buf); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := t2.WriteContentsTo(buf); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t3, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t4, err := ReadTensor(t2.DataType(), t2.Shape(), buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := t3.Value().(int8); !ok || v != 7 {
|
||||||
|
t.Errorf("Got (%v (%T), %v), want (7 (int8), true)", v, v, ok)
|
||||||
|
}
|
||||||
|
if v, ok := t4.Value().(float32); !ok || v != 2.718 {
|
||||||
|
t.Errorf("Got (%v (%T), %v), want (2.718 (float32), true)", v, v, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTensorSerializationErrors(t *testing.T) {
|
||||||
|
// String tensors cannot be serialized
|
||||||
|
t1, err := NewTensor("abcd")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
if n, err := t1.WriteContentsTo(buf); n != 0 || err == nil || buf.Len() != 0 {
|
||||||
|
t.Errorf("Got (%v, %v, %v) want (0, <non-nil>, 0)", n, err, buf.Len())
|
||||||
|
}
|
||||||
|
// Should fail to read a truncated value.
|
||||||
|
if t1, err = NewTensor(int8(8)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
n, err := t1.WriteContentsTo(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
r := bytes.NewReader(buf.Bytes()[:n-1])
|
||||||
|
if _, err = ReadTensor(t1.DataType(), t1.Shape(), r); err == nil {
|
||||||
|
t.Error("ReadTensor should have failed if the tensor content was truncated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func benchmarkNewTensor(b *testing.B, v interface{}) {
|
func benchmarkNewTensor(b *testing.B, v interface{}) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if t, err := NewTensor(v); err != nil || t == nil {
|
if t, err := NewTensor(v); err != nil || t == nil {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user