diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index ac28c3ac5bd..60de1e1a29e 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -495,3 +495,34 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu } return nil } + +type LibraryHandler struct { + cptr *C.TF_Library +} + +// Load library content into current context, useful to load ops implementation into non-monolitic TF build. Returns LibraryHandler or nil and error +func LoadLibrary(path string) (*LibraryHandler, error) { + status := newStatus() + + cpath := C.CString(path) + defer C.free(unsafe.Pointer(cpath)) + cptr := C.TF_LoadLibrary(cpath, status.c) + if cptr == nil || status.Code() != C.TF_OK { + return nil, fmt.Errorf("could not load library %s: code: %d, error: %s", path, status.Code(), status.String()) + } + + lh := &LibraryHandler{ + cptr: cptr, + } + + runtime.SetFinalizer(lh, (*LibraryHandler).free) + return lh, nil +} + +func (lh *LibraryHandler) free() { + if lh == nil || lh.cptr == nil { + return + } + + C.TF_DeleteLibraryHandle(lh.cptr) +} diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 9221d35274c..d9036ced325 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -83,7 +83,7 @@ func NewTensor(value interface{}) (*Tensor, error) { return nil, err } nflattened := numElements(shape) - nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened) + nbytes := TypeOf(dataType, nil).Size() * uintptr(nflattened) if dataType == String { nbytes = uintptr(nflattened) * C.sizeof_TF_TString } @@ -168,7 +168,7 @@ func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) if err := isTensorSerializable(dataType); err != nil { return nil, err } - nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape)) + nbytes := TypeOf(dataType, nil).Size() * uintptr(numElements(shape)) var shapePtr *C.int64_t if len(shape) > 0 { shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) @@ -207,6 +207,28 @@ func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) } // Shape returns the shape of the Tensor. func (t *Tensor) Shape() []int64 { return t.shape } +// Reshape updates tensor's shape in place if this is possible or returns an error otherwise. +func (t *Tensor) Reshape(new_shape []int64) error { + old_shape_size := numElements(t.shape) + new_shape_size := numElements(new_shape) + + if old_shape_size != new_shape_size { + return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, old_shape_size, new_shape, new_shape_size) + } + + if len(new_shape) == 0 { + return nil + } + + var shapePtr *C.int64_t + shapePtr = (*C.int64_t)(unsafe.Pointer(&new_shape[0])) + + status := newStatus() + C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(new_shape)), status.c) + + return status.Err() +} + // Value converts the Tensor to a Go value. For now, not all Tensor types are // supported, and this function may panic if it encounters an unsupported // DataType. @@ -407,8 +429,8 @@ func typeForDataType(dt DataType) reflect.Type { panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt)) } -// typeOf converts from a DataType and Shape to the equivalent Go type. -func typeOf(dt DataType, shape []int64) reflect.Type { +// TypeOf converts from a DataType and Shape to the equivalent Go type. +func TypeOf(dt DataType, shape []int64) reflect.Type { ret := typeForDataType(dt) for range shape { ret = reflect.SliceOf(ret)