Add function to create serialized ConfigOptions protos
This commit is contained in:
parent
6108cbd1db
commit
537ba63547
@ -18,6 +18,7 @@ package tensorflow
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "tensorflow/c/c_api.h"
|
||||
// #include "tensorflow/c/c_api_experimental.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
@ -349,6 +350,27 @@ func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewConfigOptions generates a serialized ConfigOptions protobuf for use
|
||||
// in the `Config` field of a `SessionOptions` struct. The function only supports
|
||||
// the session options available via TensorFlow's experimental C API function
|
||||
// `TF_CreateConfig()`. THIS API IS UNSTABLE, and the set of available options
|
||||
// will change in future versions of TensorFlow.
|
||||
func NewConfigOptions(enableXLACompilation bool, gpuMemoryAllowGrowth bool, numCPUDevices uint) []byte {
|
||||
// C API expects unsigned chars
|
||||
enableXLACompilationAsChar := C.uchar(0)
|
||||
if enableXLACompilation {
|
||||
enableXLACompilationAsChar = 1
|
||||
}
|
||||
gpuMemoryAllowGrowthAsChar := C.uchar(0)
|
||||
if gpuMemoryAllowGrowth {
|
||||
gpuMemoryAllowGrowthAsChar = 1
|
||||
}
|
||||
buf := C.TF_CreateConfig(enableXLACompilationAsChar, gpuMemoryAllowGrowthAsChar, C.uint(numCPUDevices))
|
||||
defer C.TF_DeleteBuffer(buf)
|
||||
// Copy out of C memory.
|
||||
return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length))
|
||||
}
|
||||
|
||||
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
|
||||
// values suitable for C library calls.
|
||||
type cRunArgs struct {
|
||||
|
@ -270,7 +270,8 @@ func TestSessionConfig(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
opts := SessionOptions{Config: []byte("(\x01")}
|
||||
config := NewConfigOptions(true, true, 1)
|
||||
opts := SessionOptions{Config: config}
|
||||
s, err := NewSession(graph, &opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
Loading…
Reference in New Issue
Block a user