Merge pull request #42413 from bioothod:export_golang_functions_master

PiperOrigin-RevId: 327913260
Change-Id: I53c662d45c68eb56e0b5002afd5a48119b240bd6
This commit is contained in:
TensorFlower Gardener 2020-08-21 19:53:18 -07:00
commit fcdb58d185
2 changed files with 57 additions and 4 deletions

View File

@ -495,3 +495,34 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
}
return nil
}
type LibraryHandler struct {
cptr *C.TF_Library
}
// Load library content into current context, useful to load ops implementation into non-monolitic TF build. Returns LibraryHandler or nil and error
func LoadLibrary(path string) (*LibraryHandler, error) {
status := newStatus()
cpath := C.CString(path)
defer C.free(unsafe.Pointer(cpath))
cptr := C.TF_LoadLibrary(cpath, status.c)
if cptr == nil || status.Code() != C.TF_OK {
return nil, fmt.Errorf("could not load library %s: code: %d, error: %s", path, status.Code(), status.String())
}
lh := &LibraryHandler{
cptr: cptr,
}
runtime.SetFinalizer(lh, (*LibraryHandler).free)
return lh, nil
}
func (lh *LibraryHandler) free() {
if lh == nil || lh.cptr == nil {
return
}
C.TF_DeleteLibraryHandle(lh.cptr)
}

View File

@ -83,7 +83,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
return nil, err
}
nflattened := numElements(shape)
nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
nbytes := TypeOf(dataType, nil).Size() * uintptr(nflattened)
if dataType == String {
nbytes = uintptr(nflattened) * C.sizeof_TF_TString
}
@ -168,7 +168,7 @@ func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error)
if err := isTensorSerializable(dataType); err != nil {
return nil, err
}
nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
nbytes := TypeOf(dataType, nil).Size() * uintptr(numElements(shape))
var shapePtr *C.int64_t
if len(shape) > 0 {
shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
@ -207,6 +207,28 @@ func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) }
// Shape returns the shape of the Tensor.
func (t *Tensor) Shape() []int64 { return t.shape }
// Reshape updates tensor's shape in place if this is possible or returns an error otherwise.
func (t *Tensor) Reshape(new_shape []int64) error {
old_shape_size := numElements(t.shape)
new_shape_size := numElements(new_shape)
if old_shape_size != new_shape_size {
return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, old_shape_size, new_shape, new_shape_size)
}
if len(new_shape) == 0 {
return nil
}
var shapePtr *C.int64_t
shapePtr = (*C.int64_t)(unsafe.Pointer(&new_shape[0]))
status := newStatus()
C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(new_shape)), status.c)
return status.Err()
}
// Value converts the Tensor to a Go value. For now, not all Tensor types are
// supported, and this function may panic if it encounters an unsupported
// DataType.
@ -407,8 +429,8 @@ func typeForDataType(dt DataType) reflect.Type {
panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
}
// typeOf converts from a DataType and Shape to the equivalent Go type.
func typeOf(dt DataType, shape []int64) reflect.Type {
// TypeOf converts from a DataType and Shape to the equivalent Go type.
func TypeOf(dt DataType, shape []int64) reflect.Type {
ret := typeForDataType(dt)
for range shape {
ret = reflect.SliceOf(ret)