Add Go wrapper around Eager C API's TensorHandle.
This is part of a series of changes to provide a thin Go wrapper around the Eager C API. PiperOrigin-RevId: 225852836
This commit is contained in:
parent
b8fa200095
commit
24cad5cdc5
154
tensorflow/go/tensor_handle.go
Normal file
154
tensorflow/go/tensor_handle.go
Normal file
@ -0,0 +1,154 @@
|
||||
/*
|
||||
Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package tensorflow
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "tensorflow/c/c_api.h"
|
||||
// #include "tensorflow/c/eager/c_api.h"
|
||||
import "C"
|
||||
import (
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// TensorHandle is a handle to a tensor on a device.
|
||||
//
|
||||
// A Tensor referenced by a TensorHandle may be on any device, whereas a Tensor
|
||||
// always resides in the host CPU's memory.
|
||||
//
|
||||
// A Tensor referenced by a TensorHandle may not have been computed yet. For
|
||||
// example, a TensorHandle might reference the output of an operation that has
|
||||
// not finished executing. Because of this, various methods, such as Shape() may
|
||||
// block until the tensor has been instantiated.
|
||||
//
|
||||
// This allows multiple operations to be performed on tensors on a device
|
||||
// (e.g. a GPU) without sending these values back to the host CPU in between
|
||||
// every operation.
|
||||
type TensorHandle struct {
|
||||
c *C.TFE_TensorHandle
|
||||
}
|
||||
|
||||
// NewTensorHandle creates a new tensor handle from a tensor.
|
||||
func NewTensorHandle(t *Tensor) (*TensorHandle, error) {
|
||||
status := newStatus()
|
||||
cHandle := C.TFE_NewTensorHandle(t.c, status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
th := &TensorHandle{c: cHandle}
|
||||
runtime.SetFinalizer(th, (*TensorHandle).finalizer)
|
||||
return th, nil
|
||||
}
|
||||
|
||||
func (th *TensorHandle) finalizer() {
|
||||
C.TFE_DeleteTensorHandle(th.c)
|
||||
}
|
||||
|
||||
// DataType returns the TensorHandle's datatype.
|
||||
func (th *TensorHandle) DataType() DataType {
|
||||
return DataType(C.TFE_TensorHandleDataType(th.c))
|
||||
}
|
||||
|
||||
// Shape returns the shape of the Tensor referenced by th.
|
||||
func (th *TensorHandle) Shape() ([]int64, error) {
|
||||
n, err := th.numDims()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := make([]int64, n)
|
||||
for i := 0; i < n; i++ {
|
||||
if r[i], err = th.dim(i); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// numDims returns the number of dimensions of the TensorHandle. It blocks
|
||||
// until the operation that produces the handle has completed.
|
||||
func (th *TensorHandle) numDims() (int, error) {
|
||||
status := newStatus()
|
||||
n := int(C.TFE_TensorHandleNumDims(th.c, status.c))
|
||||
return n, status.Err()
|
||||
}
|
||||
|
||||
// dim returns the size of the index'th dimension of the TensorHandle. It
|
||||
// blocks until the operation that produces the handle has completed.
|
||||
func (th *TensorHandle) dim(index int) (int64, error) {
|
||||
status := newStatus()
|
||||
n := int64(C.TFE_TensorHandleDim(th.c, C.int(index), status.c))
|
||||
if err := status.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// DeviceName returns the name of the device of the operation that produced the
|
||||
// TensorHandle. If the handle was produced by a copy, it returns the
|
||||
// destination device of the copy. Note that returned device name is not always
|
||||
// the device holding the tensor handle's memory. If you want the latter, use
|
||||
// BackingDeviceName. This function will block till the operation that produces
|
||||
// th has completed.
|
||||
func (th *TensorHandle) DeviceName() (string, error) {
|
||||
status := newStatus()
|
||||
name := C.TFE_TensorHandleDeviceName(th.c, status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return C.GoString(name), nil
|
||||
}
|
||||
|
||||
// BackingDeviceName returns the name of the device in whose memory the tensor
|
||||
// handle resides. This function will block till the operation that produces
|
||||
// `h` has completed.
|
||||
func (th *TensorHandle) BackingDeviceName() (string, error) {
|
||||
status := newStatus()
|
||||
name := C.TFE_TensorHandleBackingDeviceName(th.c, status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return C.GoString(name), nil
|
||||
}
|
||||
|
||||
// ToTensor returns the Tensor referenced by th. It may block if this tensor is
|
||||
// not yet computed.
|
||||
func (th *TensorHandle) ToTensor() (*Tensor, error) {
|
||||
status := newStatus()
|
||||
cTensor := C.TFE_TensorHandleResolve(th.c, status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newTensorFromC(cTensor), nil
|
||||
}
|
||||
|
||||
// CopyToDevice creates a new TensorHandle with the same contents as this
|
||||
// TensorHandle but placed in the memory of the device 'deviceName'. If source
|
||||
// and destination are the same device, then this creates a new handle that
|
||||
// shares the underlying buffer. Otherwise, it currently requires at least one
|
||||
// of the source or destination devices to be CPU (i.e., for the source or
|
||||
// destination tensor to be placed in host memory).
|
||||
func (th *TensorHandle) CopyToDevice(c *Context, deviceName string) (*TensorHandle, error) {
|
||||
status := newStatus()
|
||||
n := C.CString(deviceName)
|
||||
newTh := C.TFE_TensorHandleCopyToDevice(th.c, c.c, n, status.c)
|
||||
C.free(unsafe.Pointer(n))
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TensorHandle{c: newTh}, nil
|
||||
}
|
127
tensorflow/go/tensor_handle_test.go
Normal file
127
tensorflow/go/tensor_handle_test.go
Normal file
@ -0,0 +1,127 @@
|
||||
/*
|
||||
Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewTensorHandle(t *testing.T) {
|
||||
vals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
|
||||
tensor, err := NewTensor(vals)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err = NewTensorHandle(tensor); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorHandleDataType(t *testing.T) {
|
||||
vals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
|
||||
tensor, err := NewTensor(vals)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
th, err := NewTensorHandle(tensor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := th.DataType(), Float; got != want {
|
||||
t.Errorf("Got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorHandleShape(t *testing.T) {
|
||||
vals := [][]float32{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}
|
||||
tensor, err := NewTensor(vals)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
th, err := NewTensorHandle(tensor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := th.Shape()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := []int64{2, 3}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Got %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorHandleDeviceName(t *testing.T) {
|
||||
vals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
|
||||
tensor, err := NewTensor(vals)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
th, err := NewTensorHandle(tensor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
d, err := th.DeviceName()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(d, "CPU") {
|
||||
t.Errorf("DeviceName() did not return a CPU device; got: %s", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorHandleBackingDeviceName(t *testing.T) {
|
||||
vals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
|
||||
tensor, err := NewTensor(vals)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
th, err := NewTensorHandle(tensor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
d, err := th.BackingDeviceName()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(d, "CPU") {
|
||||
t.Errorf("BackingDeviceName() did not return a CPU device; got: %s", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorHandleToTensor(t *testing.T) {
|
||||
initialVals := [][]float32{{1.0, 2.0}, {3.0, 4.0}}
|
||||
initialTensor, err := NewTensor(initialVals)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
th, err := NewTensorHandle(initialTensor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tensor, err := th.ToTensor()
|
||||
if v := tensor.Value().([][]float32); !reflect.DeepEqual(v, initialVals) {
|
||||
t.Errorf("Got %#v, want %#v", v, initialVals)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user