Merge pull request #42413 from bioothod:export_golang_functions_master
PiperOrigin-RevId: 327913260 Change-Id: I53c662d45c68eb56e0b5002afd5a48119b240bd6
This commit is contained in:
commit
fcdb58d185
@ -495,3 +495,34 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type LibraryHandler struct {
|
||||
cptr *C.TF_Library
|
||||
}
|
||||
|
||||
// Load library content into current context, useful to load ops implementation into non-monolitic TF build. Returns LibraryHandler or nil and error
|
||||
func LoadLibrary(path string) (*LibraryHandler, error) {
|
||||
status := newStatus()
|
||||
|
||||
cpath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cpath))
|
||||
cptr := C.TF_LoadLibrary(cpath, status.c)
|
||||
if cptr == nil || status.Code() != C.TF_OK {
|
||||
return nil, fmt.Errorf("could not load library %s: code: %d, error: %s", path, status.Code(), status.String())
|
||||
}
|
||||
|
||||
lh := &LibraryHandler{
|
||||
cptr: cptr,
|
||||
}
|
||||
|
||||
runtime.SetFinalizer(lh, (*LibraryHandler).free)
|
||||
return lh, nil
|
||||
}
|
||||
|
||||
func (lh *LibraryHandler) free() {
|
||||
if lh == nil || lh.cptr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
C.TF_DeleteLibraryHandle(lh.cptr)
|
||||
}
|
||||
|
@ -83,7 +83,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
nflattened := numElements(shape)
|
||||
nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
|
||||
nbytes := TypeOf(dataType, nil).Size() * uintptr(nflattened)
|
||||
if dataType == String {
|
||||
nbytes = uintptr(nflattened) * C.sizeof_TF_TString
|
||||
}
|
||||
@ -168,7 +168,7 @@ func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error)
|
||||
if err := isTensorSerializable(dataType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
|
||||
nbytes := TypeOf(dataType, nil).Size() * uintptr(numElements(shape))
|
||||
var shapePtr *C.int64_t
|
||||
if len(shape) > 0 {
|
||||
shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
|
||||
@ -207,6 +207,28 @@ func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) }
|
||||
// Shape returns the shape of the Tensor.
|
||||
func (t *Tensor) Shape() []int64 { return t.shape }
|
||||
|
||||
// Reshape updates tensor's shape in place if this is possible or returns an error otherwise.
|
||||
func (t *Tensor) Reshape(new_shape []int64) error {
|
||||
old_shape_size := numElements(t.shape)
|
||||
new_shape_size := numElements(new_shape)
|
||||
|
||||
if old_shape_size != new_shape_size {
|
||||
return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, old_shape_size, new_shape, new_shape_size)
|
||||
}
|
||||
|
||||
if len(new_shape) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var shapePtr *C.int64_t
|
||||
shapePtr = (*C.int64_t)(unsafe.Pointer(&new_shape[0]))
|
||||
|
||||
status := newStatus()
|
||||
C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(new_shape)), status.c)
|
||||
|
||||
return status.Err()
|
||||
}
|
||||
|
||||
// Value converts the Tensor to a Go value. For now, not all Tensor types are
|
||||
// supported, and this function may panic if it encounters an unsupported
|
||||
// DataType.
|
||||
@ -407,8 +429,8 @@ func typeForDataType(dt DataType) reflect.Type {
|
||||
panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
|
||||
}
|
||||
|
||||
// typeOf converts from a DataType and Shape to the equivalent Go type.
|
||||
func typeOf(dt DataType, shape []int64) reflect.Type {
|
||||
// TypeOf converts from a DataType and Shape to the equivalent Go type.
|
||||
func TypeOf(dt DataType, shape []int64) reflect.Type {
|
||||
ret := typeForDataType(dt)
|
||||
for range shape {
|
||||
ret = reflect.SliceOf(ret)
|
||||
|
Loading…
Reference in New Issue
Block a user