Add Go wrapper around Eager C API's TensorHandle.

This is part of a series of changes to provide a thin Go wrapper around the Eager C API.

PiperOrigin-RevId: 225852836
This commit is contained in:
James Keeling 2018-12-17 10:50:50 -08:00 committed by TensorFlower Gardener
parent b8fa200095
commit 24cad5cdc5
2 changed files with 281 additions and 0 deletions

View File

@ -0,0 +1,154 @@
/*
Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package tensorflow
// #include <stdlib.h>
// #include "tensorflow/c/c_api.h"
// #include "tensorflow/c/eager/c_api.h"
import "C"
import (
"runtime"
"unsafe"
)
// TensorHandle is a handle to a tensor on a device.
//
// A Tensor referenced by a TensorHandle may be on any device, whereas a Tensor
// always resides in the host CPU's memory.
//
// A Tensor referenced by a TensorHandle may not have been computed yet. For
// example, a TensorHandle might reference the output of an operation that has
// not finished executing. Because of this, various methods, such as Shape() may
// block until the tensor has been instantiated.
//
// This allows multiple operations to be performed on tensors on a device
// (e.g. a GPU) without sending these values back to the host CPU in between
// every operation.
type TensorHandle struct {
c *C.TFE_TensorHandle
}
// NewTensorHandle creates a new tensor handle from a tensor.
func NewTensorHandle(t *Tensor) (*TensorHandle, error) {
status := newStatus()
cHandle := C.TFE_NewTensorHandle(t.c, status.c)
if err := status.Err(); err != nil {
return nil, err
}
th := &TensorHandle{c: cHandle}
runtime.SetFinalizer(th, (*TensorHandle).finalizer)
return th, nil
}
func (th *TensorHandle) finalizer() {
C.TFE_DeleteTensorHandle(th.c)
}
// DataType returns the TensorHandle's datatype.
func (th *TensorHandle) DataType() DataType {
return DataType(C.TFE_TensorHandleDataType(th.c))
}
// Shape returns the shape of the Tensor referenced by th.
func (th *TensorHandle) Shape() ([]int64, error) {
n, err := th.numDims()
if err != nil {
return nil, err
}
r := make([]int64, n)
for i := 0; i < n; i++ {
if r[i], err = th.dim(i); err != nil {
return nil, err
}
}
return r, nil
}
// numDims returns the number of dimensions of the TensorHandle. It blocks
// until the operation that produces the handle has completed.
func (th *TensorHandle) numDims() (int, error) {
status := newStatus()
n := int(C.TFE_TensorHandleNumDims(th.c, status.c))
return n, status.Err()
}
// dim returns the size of the index'th dimension of the TensorHandle. It
// blocks until the operation that produces the handle has completed.
func (th *TensorHandle) dim(index int) (int64, error) {
status := newStatus()
n := int64(C.TFE_TensorHandleDim(th.c, C.int(index), status.c))
if err := status.Err(); err != nil {
return 0, err
}
return n, nil
}
// DeviceName returns the name of the device of the operation that produced the
// TensorHandle. If the handle was produced by a copy, it returns the
// destination device of the copy. Note that returned device name is not always
// the device holding the tensor handle's memory. If you want the latter, use
// BackingDeviceName. This function will block till the operation that produces
// th has completed.
func (th *TensorHandle) DeviceName() (string, error) {
status := newStatus()
name := C.TFE_TensorHandleDeviceName(th.c, status.c)
if err := status.Err(); err != nil {
return "", err
}
return C.GoString(name), nil
}
// BackingDeviceName returns the name of the device in whose memory the tensor
// handle resides. This function will block till the operation that produces
// `h` has completed.
func (th *TensorHandle) BackingDeviceName() (string, error) {
status := newStatus()
name := C.TFE_TensorHandleBackingDeviceName(th.c, status.c)
if err := status.Err(); err != nil {
return "", err
}
return C.GoString(name), nil
}
// ToTensor returns the Tensor referenced by th. It may block if this tensor is
// not yet computed.
func (th *TensorHandle) ToTensor() (*Tensor, error) {
status := newStatus()
cTensor := C.TFE_TensorHandleResolve(th.c, status.c)
if err := status.Err(); err != nil {
return nil, err
}
return newTensorFromC(cTensor), nil
}
// CopyToDevice creates a new TensorHandle with the same contents as this
// TensorHandle but placed in the memory of the device 'deviceName'. If source
// and destination are the same device, then this creates a new handle that
// shares the underlying buffer. Otherwise, it currently requires at least one
// of the source or destination devices to be CPU (i.e., for the source or
// destination tensor to be placed in host memory).
func (th *TensorHandle) CopyToDevice(c *Context, deviceName string) (*TensorHandle, error) {
status := newStatus()
n := C.CString(deviceName)
newTh := C.TFE_TensorHandleCopyToDevice(th.c, c.c, n, status.c)
C.free(unsafe.Pointer(n))
if err := status.Err(); err != nil {
return nil, err
}
return &TensorHandle{c: newTh}, nil
}

View File

@ -0,0 +1,127 @@
/*
Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package tensorflow
import (
"reflect"
"strings"
"testing"
)
func TestNewTensorHandle(t *testing.T) {
vals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
tensor, err := NewTensor(vals)
if err != nil {
t.Fatal(err)
}
if _, err = NewTensorHandle(tensor); err != nil {
t.Fatal(err)
}
}
func TestTensorHandleDataType(t *testing.T) {
vals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
tensor, err := NewTensor(vals)
if err != nil {
t.Fatal(err)
}
th, err := NewTensorHandle(tensor)
if err != nil {
t.Fatal(err)
}
if got, want := th.DataType(), Float; got != want {
t.Errorf("Got %v, want %v", got, want)
}
}
func TestTensorHandleShape(t *testing.T) {
vals := [][]float32{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}
tensor, err := NewTensor(vals)
if err != nil {
t.Fatal(err)
}
th, err := NewTensorHandle(tensor)
if err != nil {
t.Fatal(err)
}
got, err := th.Shape()
if err != nil {
t.Fatal(err)
}
if want := []int64{2, 3}; !reflect.DeepEqual(got, want) {
t.Errorf("Got %#v, want %#v", got, want)
}
}
func TestTensorHandleDeviceName(t *testing.T) {
vals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
tensor, err := NewTensor(vals)
if err != nil {
t.Fatal(err)
}
th, err := NewTensorHandle(tensor)
if err != nil {
t.Fatal(err)
}
d, err := th.DeviceName()
if err != nil {
t.Fatal(err)
}
if !strings.Contains(d, "CPU") {
t.Errorf("DeviceName() did not return a CPU device; got: %s", d)
}
}
func TestTensorHandleBackingDeviceName(t *testing.T) {
vals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
tensor, err := NewTensor(vals)
if err != nil {
t.Fatal(err)
}
th, err := NewTensorHandle(tensor)
if err != nil {
t.Fatal(err)
}
d, err := th.BackingDeviceName()
if err != nil {
t.Fatal(err)
}
if !strings.Contains(d, "CPU") {
t.Errorf("BackingDeviceName() did not return a CPU device; got: %s", d)
}
}
func TestTensorHandleToTensor(t *testing.T) {
initialVals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
initialTensor, err := NewTensor(initialVals)
if err != nil {
t.Fatal(err)
}
th, err := NewTensorHandle(initialTensor)
if err != nil {
t.Fatal(err)
}
tensor, err := th.ToTensor()
if v := tensor.Value().([][]float32); !reflect.DeepEqual(v, initialVals) {
t.Errorf("Got %#v, want %#v", v, initialVals)
}
}