diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index fc0716bb70f..48909ffe39e 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -18,7 +18,6 @@ package tensorflow // #include // #include "tensorflow/c/c_api.h" -// #include "tensorflow/c/c_api_experimental.h" import "C" import ( @@ -316,11 +315,6 @@ type SessionOptions struct { // Config is a binary-serialized representation of the // tensorflow.ConfigProto protocol message // (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto). - // You can populate this field in three ways. You can use the zero value - // of this field, which configures the session with a default set of - // options. Or you create a `Config` struct with your options and call that - // struct's `Bytes()` method. Or you can generate a byte string outside of - // Go and paste that string into your Go program as a string literal. Config []byte } @@ -355,81 +349,6 @@ func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) }, nil } -// JitLevel represents the level of optimization that the XLA compiler -// performs during just-in-time compilation. -type JitLevel int - -const ( - // JitDefault is the default setting for this version of TensorFlow - // and corresponds to the "DEFAULT" level in the Python API. - // Currently the default is `JitOff`, but it will change to `JitOn` in a - // future version. - JitDefault JitLevel = 0 - // JitOff disables just-in-time compilation and will continue to do so - // even after JIT compilation is enabled by default. - JitOff JitLevel = -1 - // JitOn enables just-in-time compilation. It is a synonym for the - // "ON_1" optimization level in the Python API. - JitOn JitLevel = 1 -) - -// Config represents session parameters as encoded in the tensorflow.ConfigProto -// protocol buffer message. -type Config struct { - // GlobalJitLevel controls the degree of optimization that the XLA just-in-time - // compiler will perform. The default is currently "off", but it is expected - // to change to "on" in a future version of TensorFlow. - GlobalJitLevel JitLevel - - // AllowGPUMemoryGrowth controls whether the TensorFlow memory allocator - // pre-allocates the entire specified GPU memory region or instead starts - // with a small block of GPU memory and grows its memory usage as needed. - AllowGPUMemoryGrowth bool - - // NumCPUs is the maximum number of CPU devices that the session will use. - // A value of 0 means "let the system pick an appropriate number" - NumCPUs int - - // This struct only exposes the session options available via TensorFlow's - // experimental C API function `TF_CreateConfig()`. - // TODO(frreiss): Add additional options here as more session options are - // exposed via the C API. -} - -// Bytes generates a serialized ConfigOptions protobuf for use in the `Config` -// field of a `SessionOptions` struct. -func (c *Config) Bytes() []byte { - // The C API expects an unsigned char that is 0 if XLA compilation is off and - // nonzero otherwise. - // There is currently no way in the C API to specify "use TensorFlow's default - // JIT level". The translation logic here ensures that the zero value of - // c.GlobalJitLevel means the same as the default value of - // OptimizerOptions.global_jit_level in the Python API. - enableXLACompilationAsChar := C.uchar(0) - switch c.GlobalJitLevel { - case JitDefault: - // TODO(frreiss): When the semantics of GlobalJitLevel.DEFAULT change to - // "on", uncomment the following line. - // enableXLACompilationAsChar = C.uchar(1) - case JitOn: - enableXLACompilationAsChar = C.uchar(1) - } - gpuMemoryAllowGrowthAsChar := C.uchar(0) - if c.AllowGPUMemoryGrowth { - gpuMemoryAllowGrowthAsChar = 1 - } - // The C API doesn't currently have a way to say "let the system pick how many - // CPUs to use," so detect the number of CPUs here. - numCPUDevicesAsUint := C.uint(runtime.NumCPU()) - if c.NumCPUs > 0 { - numCPUDevicesAsUint = C.uint(c.NumCPUs) - } - buf := C.TF_CreateConfig(enableXLACompilationAsChar, gpuMemoryAllowGrowthAsChar, numCPUDevicesAsUint) - defer C.TF_DeleteBuffer(buf) - // Copy out of C memory. - return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length)) -} - // cRunArgs translates the arguments to Session.Run and PartialRun.Run into // values suitable for C library calls. type cRunArgs struct { diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index 510108de1ae..c9bda001671 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -250,15 +250,27 @@ func ExamplePartialRun() { } func TestSessionConfig(t *testing.T) { - // Exercise SessionOptions and Config structs + // Exercise SessionOptions. + // Arguably, a better API would be for SessionOptions.Config to be the + // type generated by the protocol buffer compiler. But for now, the + // tensorflow package continues to be independent of protocol buffers + // and this test exercises the option since the implementation has a + // nuanced conversion to C types. + // + // Till then, the []byte form of Config here was generated using a toy + // tensorflow Python program: + /* + import tensorflow + c = tensorflow.ConfigProto() + c.intra_op_parallelism_threads = 1 + print c.SerializeToString() + */ graph := NewGraph() c, err := Const(graph, "Const", int32(14)) if err != nil { t.Fatal(err) } - // Use the zero values for Config.GlobalJitLevel and NumCPUs - config := Config{AllowGPUMemoryGrowth: true} - opts := SessionOptions{Config: config.Bytes()} + opts := SessionOptions{Config: []byte("(\x01")} s, err := NewSession(graph, &opts) if err != nil { t.Fatal(err)