Merge pull request #26682 from frreiss:issue-go_proto
PiperOrigin-RevId: 239070046
This commit is contained in:
commit
0bc46a1602
@ -18,6 +18,7 @@ package tensorflow
|
|||||||
|
|
||||||
// #include <stdlib.h>
|
// #include <stdlib.h>
|
||||||
// #include "tensorflow/c/c_api.h"
|
// #include "tensorflow/c/c_api.h"
|
||||||
|
// #include "tensorflow/c/c_api_experimental.h"
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -315,6 +316,11 @@ type SessionOptions struct {
|
|||||||
// Config is a binary-serialized representation of the
|
// Config is a binary-serialized representation of the
|
||||||
// tensorflow.ConfigProto protocol message
|
// tensorflow.ConfigProto protocol message
|
||||||
// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
|
// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
|
||||||
|
// You can populate this field in three ways. You can use the zero value
|
||||||
|
// of this field, which configures the session with a default set of
|
||||||
|
// options. Or you create a `Config` struct with your options and call that
|
||||||
|
// struct's `Bytes()` method. Or you can generate a byte string outside of
|
||||||
|
// Go and paste that string into your Go program as a string literal.
|
||||||
Config []byte
|
Config []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -349,6 +355,81 @@ func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error)
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JitLevel represents the level of optimization that the XLA compiler
|
||||||
|
// performs during just-in-time compilation.
|
||||||
|
type JitLevel int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// JitDefault is the default setting for this version of TensorFlow
|
||||||
|
// and corresponds to the "DEFAULT" level in the Python API.
|
||||||
|
// Currently the default is `JitOff`, but it will change to `JitOn` in a
|
||||||
|
// future version.
|
||||||
|
JitDefault JitLevel = 0
|
||||||
|
// JitOff disables just-in-time compilation and will continue to do so
|
||||||
|
// even after JIT compilation is enabled by default.
|
||||||
|
JitOff JitLevel = -1
|
||||||
|
// JitOn enables just-in-time compilation. It is a synonym for the
|
||||||
|
// "ON_1" optimization level in the Python API.
|
||||||
|
JitOn JitLevel = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config represents session parameters as encoded in the tensorflow.ConfigProto
|
||||||
|
// protocol buffer message.
|
||||||
|
type Config struct {
|
||||||
|
// GlobalJitLevel controls the degree of optimization that the XLA just-in-time
|
||||||
|
// compiler will perform. The default is currently "off", but it is expected
|
||||||
|
// to change to "on" in a future version of TensorFlow.
|
||||||
|
GlobalJitLevel JitLevel
|
||||||
|
|
||||||
|
// AllowGPUMemoryGrowth controls whether the TensorFlow memory allocator
|
||||||
|
// pre-allocates the entire specified GPU memory region or instead starts
|
||||||
|
// with a small block of GPU memory and grows its memory usage as needed.
|
||||||
|
AllowGPUMemoryGrowth bool
|
||||||
|
|
||||||
|
// NumCPUs is the maximum number of CPU devices that the session will use.
|
||||||
|
// A value of 0 means "let the system pick an appropriate number"
|
||||||
|
NumCPUs int
|
||||||
|
|
||||||
|
// This struct only exposes the session options available via TensorFlow's
|
||||||
|
// experimental C API function `TF_CreateConfig()`.
|
||||||
|
// TODO(frreiss): Add additional options here as more session options are
|
||||||
|
// exposed via the C API.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes generates a serialized ConfigOptions protobuf for use in the `Config`
|
||||||
|
// field of a `SessionOptions` struct.
|
||||||
|
func (c *Config) Bytes() []byte {
|
||||||
|
// The C API expects an unsigned char that is 0 if XLA compilation is off and
|
||||||
|
// nonzero otherwise.
|
||||||
|
// There is currently no way in the C API to specify "use TensorFlow's default
|
||||||
|
// JIT level". The translation logic here ensures that the zero value of
|
||||||
|
// c.GlobalJitLevel means the same as the default value of
|
||||||
|
// OptimizerOptions.global_jit_level in the Python API.
|
||||||
|
enableXLACompilationAsChar := C.uchar(0)
|
||||||
|
switch c.GlobalJitLevel {
|
||||||
|
case JitDefault:
|
||||||
|
// TODO(frreiss): When the semantics of GlobalJitLevel.DEFAULT change to
|
||||||
|
// "on", uncomment the following line.
|
||||||
|
// enableXLACompilationAsChar = C.uchar(1)
|
||||||
|
case JitOn:
|
||||||
|
enableXLACompilationAsChar = C.uchar(1)
|
||||||
|
}
|
||||||
|
gpuMemoryAllowGrowthAsChar := C.uchar(0)
|
||||||
|
if c.AllowGPUMemoryGrowth {
|
||||||
|
gpuMemoryAllowGrowthAsChar = 1
|
||||||
|
}
|
||||||
|
// The C API doesn't currently have a way to say "let the system pick how many
|
||||||
|
// CPUs to use," so detect the number of CPUs here.
|
||||||
|
numCPUDevicesAsUint := C.uint(runtime.NumCPU())
|
||||||
|
if c.NumCPUs > 0 {
|
||||||
|
numCPUDevicesAsUint = C.uint(c.NumCPUs)
|
||||||
|
}
|
||||||
|
buf := C.TF_CreateConfig(enableXLACompilationAsChar, gpuMemoryAllowGrowthAsChar, numCPUDevicesAsUint)
|
||||||
|
defer C.TF_DeleteBuffer(buf)
|
||||||
|
// Copy out of C memory.
|
||||||
|
return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length))
|
||||||
|
}
|
||||||
|
|
||||||
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
|
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
|
||||||
// values suitable for C library calls.
|
// values suitable for C library calls.
|
||||||
type cRunArgs struct {
|
type cRunArgs struct {
|
||||||
|
@ -250,27 +250,15 @@ func ExamplePartialRun() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionConfig(t *testing.T) {
|
func TestSessionConfig(t *testing.T) {
|
||||||
// Exercise SessionOptions.
|
// Exercise SessionOptions and Config structs
|
||||||
// Arguably, a better API would be for SessionOptions.Config to be the
|
|
||||||
// type generated by the protocol buffer compiler. But for now, the
|
|
||||||
// tensorflow package continues to be independent of protocol buffers
|
|
||||||
// and this test exercises the option since the implementation has a
|
|
||||||
// nuanced conversion to C types.
|
|
||||||
//
|
|
||||||
// Till then, the []byte form of Config here was generated using a toy
|
|
||||||
// tensorflow Python program:
|
|
||||||
/*
|
|
||||||
import tensorflow
|
|
||||||
c = tensorflow.ConfigProto()
|
|
||||||
c.intra_op_parallelism_threads = 1
|
|
||||||
print c.SerializeToString()
|
|
||||||
*/
|
|
||||||
graph := NewGraph()
|
graph := NewGraph()
|
||||||
c, err := Const(graph, "Const", int32(14))
|
c, err := Const(graph, "Const", int32(14))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
opts := SessionOptions{Config: []byte("(\x01")}
|
// Use the zero values for Config.GlobalJitLevel and NumCPUs
|
||||||
|
config := Config{AllowGPUMemoryGrowth: true}
|
||||||
|
opts := SessionOptions{Config: config.Bytes()}
|
||||||
s, err := NewSession(graph, &opts)
|
s, err := NewSession(graph, &opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
Loading…
Reference in New Issue
Block a user