Add function to create serialized ConfigOptions protos

This commit is contained in:
frreiss 2019-03-13 16:19:23 -07:00
parent 6108cbd1db
commit 537ba63547
2 changed files with 24 additions and 1 deletions

View File

@ -18,6 +18,7 @@ package tensorflow
// #include <stdlib.h>
// #include "tensorflow/c/c_api.h"
// #include "tensorflow/c/c_api_experimental.h"
import "C"
import (
@ -349,6 +350,27 @@ func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error)
}, nil
}
// NewConfigOptions generates a serialized ConfigOptions protobuf for use
// in the `Config` field of a `SessionOptions` struct. The function only supports
// the session options available via TensorFlow's experimental C API function
// `TF_CreateConfig()`. THIS API IS UNSTABLE, and the set of available options
// will change in future versions of TensorFlow.
func NewConfigOptions(enableXLACompilation bool, gpuMemoryAllowGrowth bool, numCPUDevices uint) []byte {
// C API expects unsigned chars
enableXLACompilationAsChar := C.uchar(0)
if enableXLACompilation {
enableXLACompilationAsChar = 1
}
gpuMemoryAllowGrowthAsChar := C.uchar(0)
if gpuMemoryAllowGrowth {
gpuMemoryAllowGrowthAsChar = 1
}
buf := C.TF_CreateConfig(enableXLACompilationAsChar, gpuMemoryAllowGrowthAsChar, C.uint(numCPUDevices))
defer C.TF_DeleteBuffer(buf)
// Copy out of C memory.
return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length))
}
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
// values suitable for C library calls.
type cRunArgs struct {

View File

@ -270,7 +270,8 @@ func TestSessionConfig(t *testing.T) {
if err != nil {
t.Fatal(err)
}
opts := SessionOptions{Config: []byte("(\x01")}
config := NewConfigOptions(true, true, 1)
opts := SessionOptions{Config: config}
s, err := NewSession(graph, &opts)
if err != nil {
t.Fatal(err)