Address review comments and add some more API docs
This commit is contained in:
parent
3b3397c609
commit
aaec622b8a
@ -316,6 +316,11 @@ type SessionOptions struct {
|
||||
// Config is a binary-serialized representation of the
|
||||
// tensorflow.ConfigProto protocol message
|
||||
// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
|
||||
// You can populate this field in three ways. You can use the zero value
|
||||
// of this field, which configures the session with a default set of
|
||||
// options. Or you create a `Config` struct with your options and call that
|
||||
// struct's `Bytes()` method. Or you can generate a byte string outside of
|
||||
// Go and paste that string into your Go program as a string literal.
|
||||
Config []byte
|
||||
}
|
||||
|
||||
@ -355,15 +360,17 @@ func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error)
|
||||
type JitLevel int
|
||||
|
||||
const (
|
||||
// DEFAULT is the default setting for this version of TensorFlow,
|
||||
// Currently the default is OFF, but it will change to ON in a future
|
||||
// version.
|
||||
DEFAULT JitLevel = 0
|
||||
// OFF disables just-in-time compilation and will continue to do so
|
||||
// JitDefault is the default setting for this version of TensorFlow
|
||||
// and corresponds to the "DEFAULT" level in the Python API.
|
||||
// Currently the default is `JitOff`, but it will change to `JitOn` in a
|
||||
// future version.
|
||||
JitDefault JitLevel = 0
|
||||
// JitOff disables just-in-time compilation and will continue to do so
|
||||
// even after JIT compilation is enabled by default.
|
||||
OFF JitLevel = -1
|
||||
// ON is a synonym for the ON_1 optimization level.
|
||||
ON JitLevel = 1
|
||||
JitOff JitLevel = -1
|
||||
// JitOn enables just-in-time compilation. It is a synonym for the
|
||||
// "ON_1" optimization level in the Python API.
|
||||
JitOn JitLevel = 1
|
||||
)
|
||||
|
||||
// Config represents session parameters as encoded in the tensorflow.ConfigProto
|
||||
@ -395,16 +402,16 @@ func (c *Config) Bytes() []byte {
|
||||
// The C API expects an unsigned char that is 0 if XLA compilation is off and
|
||||
// nonzero otherwise.
|
||||
// There is currently no way in the C API to specify "use TensorFlow's default
|
||||
// JIT level". The translation logic here ensures that the zero value of
|
||||
// JIT level". The translation logic here ensures that the zero value of
|
||||
// c.GlobalJitLevel means the same as the default value of
|
||||
// OptimizerOptions.global_jit_level in the Python API.
|
||||
enableXLACompilationAsChar := C.uchar(0)
|
||||
switch c.GlobalJitLevel {
|
||||
case DEFAULT:
|
||||
case JitDefault:
|
||||
// TODO(frreiss): When the semantics of GlobalJitLevel.DEFAULT change to
|
||||
// "on", uncomment the following line.
|
||||
// enableXLACompilationAsChar = C.uchar(1)
|
||||
case ON:
|
||||
case JitOn:
|
||||
enableXLACompilationAsChar = C.uchar(1)
|
||||
}
|
||||
gpuMemoryAllowGrowthAsChar := C.uchar(0)
|
||||
|
Loading…
Reference in New Issue
Block a user