diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index 48909ffe39e..fc0716bb70f 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -18,6 +18,7 @@ package tensorflow // #include <stdlib.h> // #include "tensorflow/c/c_api.h" +// #include "tensorflow/c/c_api_experimental.h" import "C" import ( @@ -315,6 +316,11 @@ type SessionOptions struct { // Config is a binary-serialized representation of the // tensorflow.ConfigProto protocol message // (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto). + // You can populate this field in three ways. You can use the zero value + // of this field, which configures the session with a default set of + // options. Or you can create a `Config` struct with your options and call that + // struct's `Bytes()` method. Or you can generate a byte string outside of + // Go and paste that string into your Go program as a string literal. Config []byte } @@ -349,6 +355,81 @@ func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) }, nil } +// JitLevel represents the level of optimization that the XLA compiler +// performs during just-in-time compilation. +type JitLevel int + +const ( + // JitDefault is the default setting for this version of TensorFlow + // and corresponds to the "DEFAULT" level in the Python API. + // Currently the default is `JitOff`, but it will change to `JitOn` in a + // future version. + JitDefault JitLevel = 0 + // JitOff disables just-in-time compilation and will continue to do so + // even after JIT compilation is enabled by default. + JitOff JitLevel = -1 + // JitOn enables just-in-time compilation. It is a synonym for the + // "ON_1" optimization level in the Python API. + JitOn JitLevel = 1 +) + +// Config represents session parameters as encoded in the tensorflow.ConfigProto +// protocol buffer message. +type Config struct { + // GlobalJitLevel controls the degree of optimization that the XLA just-in-time + // compiler will perform. 
The default is currently "off", but it is expected + // to change to "on" in a future version of TensorFlow. + GlobalJitLevel JitLevel + + // AllowGPUMemoryGrowth controls whether the TensorFlow memory allocator + // pre-allocates the entire specified GPU memory region or instead starts + // with a small block of GPU memory and grows its memory usage as needed. + AllowGPUMemoryGrowth bool + + // NumCPUs is the maximum number of CPU devices that the session will use. + // A value of 0 means "let the system pick an appropriate number". + NumCPUs int + + // This struct only exposes the session options available via TensorFlow's + // experimental C API function `TF_CreateConfig()`. + // TODO(frreiss): Add additional options here as more session options are + // exposed via the C API. +} + +// Bytes generates a serialized ConfigProto protobuf for use in the `Config` +// field of a `SessionOptions` struct. +func (c *Config) Bytes() []byte { + // The C API expects an unsigned char that is 0 if XLA compilation is off and + // nonzero otherwise. + // There is currently no way in the C API to specify "use TensorFlow's default + // JIT level". The translation logic here ensures that the zero value of + // c.GlobalJitLevel means the same as the default value of + // OptimizerOptions.global_jit_level in the Python API. + enableXLACompilationAsChar := C.uchar(0) + switch c.GlobalJitLevel { + case JitDefault: + // TODO(frreiss): When the semantics of GlobalJitLevel.DEFAULT change to + // "on", uncomment the following line. + // enableXLACompilationAsChar = C.uchar(1) + case JitOn: + enableXLACompilationAsChar = C.uchar(1) + } + gpuMemoryAllowGrowthAsChar := C.uchar(0) + if c.AllowGPUMemoryGrowth { + gpuMemoryAllowGrowthAsChar = 1 + } + // The C API doesn't currently have a way to say "let the system pick how many + // CPUs to use," so detect the number of CPUs here. 
+ numCPUDevicesAsUint := C.uint(runtime.NumCPU()) + if c.NumCPUs > 0 { + numCPUDevicesAsUint = C.uint(c.NumCPUs) + } + buf := C.TF_CreateConfig(enableXLACompilationAsChar, gpuMemoryAllowGrowthAsChar, numCPUDevicesAsUint) + defer C.TF_DeleteBuffer(buf) + // Copy out of C memory. + return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length)) +} + // cRunArgs translates the arguments to Session.Run and PartialRun.Run into // values suitable for C library calls. type cRunArgs struct { diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index c9bda001671..510108de1ae 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -250,27 +250,15 @@ func ExamplePartialRun() { } func TestSessionConfig(t *testing.T) { - // Exercise SessionOptions. - // Arguably, a better API would be for SessionOptions.Config to be the - // type generated by the protocol buffer compiler. But for now, the - // tensorflow package continues to be independent of protocol buffers - // and this test exercises the option since the implementation has a - // nuanced conversion to C types. - // - // Till then, the []byte form of Config here was generated using a toy - // tensorflow Python program: - /* - import tensorflow - c = tensorflow.ConfigProto() - c.intra_op_parallelism_threads = 1 - print c.SerializeToString() - */ + // Exercise SessionOptions and Config structs graph := NewGraph() c, err := Const(graph, "Const", int32(14)) if err != nil { t.Fatal(err) } - opts := SessionOptions{Config: []byte("(\x01")} + // Use the zero values for Config.GlobalJitLevel and NumCPUs + config := Config{AllowGPUMemoryGrowth: true} + opts := SessionOptions{Config: config.Bytes()} s, err := NewSession(graph, &opts) if err != nil { t.Fatal(err)