diff --git a/tensorflow/go/tensor_handle.go b/tensorflow/go/tensor_handle.go index befc1c43ba1..3b06773dd16 100644 --- a/tensorflow/go/tensor_handle.go +++ b/tensorflow/go/tensor_handle.go @@ -59,6 +59,13 @@ func (th *TensorHandle) finalizer() { C.TFE_DeleteTensorHandle(th.c) } +// newTensorHandleFromC takes ownership of c and returns the owning TensorHandle. +func newTensorHandleFromC(c *C.TFE_TensorHandle) *TensorHandle { + th := &TensorHandle{c: c} + runtime.SetFinalizer(th, (*TensorHandle).finalizer) + return th +} + // DataType returns the TensorHandle's datatype. func (th *TensorHandle) DataType() DataType { return DataType(C.TFE_TensorHandleDataType(th.c)) @@ -150,5 +157,5 @@ func (th *TensorHandle) CopyToDevice(c *Context, deviceName string) (*TensorHand if err := status.Err(); err != nil { return nil, err } - return &TensorHandle{c: newTh}, nil + return newTensorHandleFromC(newTh), nil }