Automated rollback of commit 0bc46a1602
. Revert #26682.
PiperOrigin-RevId: 239098662
This commit is contained in:
parent
dc0137f16a
commit
6e9cb400d1
@ -18,7 +18,6 @@ package tensorflow
|
|||||||
|
|
||||||
// #include <stdlib.h>
|
// #include <stdlib.h>
|
||||||
// #include "tensorflow/c/c_api.h"
|
// #include "tensorflow/c/c_api.h"
|
||||||
// #include "tensorflow/c/c_api_experimental.h"
|
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -316,11 +315,6 @@ type SessionOptions struct {
|
|||||||
// Config is a binary-serialized representation of the
|
// Config is a binary-serialized representation of the
|
||||||
// tensorflow.ConfigProto protocol message
|
// tensorflow.ConfigProto protocol message
|
||||||
// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
|
// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
|
||||||
// You can populate this field in three ways. You can use the zero value
|
|
||||||
// of this field, which configures the session with a default set of
|
|
||||||
// options. Or you create a `Config` struct with your options and call that
|
|
||||||
// struct's `Bytes()` method. Or you can generate a byte string outside of
|
|
||||||
// Go and paste that string into your Go program as a string literal.
|
|
||||||
Config []byte
|
Config []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,81 +349,6 @@ func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error)
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// JitLevel represents the level of optimization that the XLA compiler
|
|
||||||
// performs during just-in-time compilation.
|
|
||||||
type JitLevel int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// JitDefault is the default setting for this version of TensorFlow
|
|
||||||
// and corresponds to the "DEFAULT" level in the Python API.
|
|
||||||
// Currently the default is `JitOff`, but it will change to `JitOn` in a
|
|
||||||
// future version.
|
|
||||||
JitDefault JitLevel = 0
|
|
||||||
// JitOff disables just-in-time compilation and will continue to do so
|
|
||||||
// even after JIT compilation is enabled by default.
|
|
||||||
JitOff JitLevel = -1
|
|
||||||
// JitOn enables just-in-time compilation. It is a synonym for the
|
|
||||||
// "ON_1" optimization level in the Python API.
|
|
||||||
JitOn JitLevel = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// Config represents session parameters as encoded in the tensorflow.ConfigProto
|
|
||||||
// protocol buffer message.
|
|
||||||
type Config struct {
|
|
||||||
// GlobalJitLevel controls the degree of optimization that the XLA just-in-time
|
|
||||||
// compiler will perform. The default is currently "off", but it is expected
|
|
||||||
// to change to "on" in a future version of TensorFlow.
|
|
||||||
GlobalJitLevel JitLevel
|
|
||||||
|
|
||||||
// AllowGPUMemoryGrowth controls whether the TensorFlow memory allocator
|
|
||||||
// pre-allocates the entire specified GPU memory region or instead starts
|
|
||||||
// with a small block of GPU memory and grows its memory usage as needed.
|
|
||||||
AllowGPUMemoryGrowth bool
|
|
||||||
|
|
||||||
// NumCPUs is the maximum number of CPU devices that the session will use.
|
|
||||||
// A value of 0 means "let the system pick an appropriate number"
|
|
||||||
NumCPUs int
|
|
||||||
|
|
||||||
// This struct only exposes the session options available via TensorFlow's
|
|
||||||
// experimental C API function `TF_CreateConfig()`.
|
|
||||||
// TODO(frreiss): Add additional options here as more session options are
|
|
||||||
// exposed via the C API.
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bytes generates a serialized ConfigOptions protobuf for use in the `Config`
|
|
||||||
// field of a `SessionOptions` struct.
|
|
||||||
func (c *Config) Bytes() []byte {
|
|
||||||
// The C API expects an unsigned char that is 0 if XLA compilation is off and
|
|
||||||
// nonzero otherwise.
|
|
||||||
// There is currently no way in the C API to specify "use TensorFlow's default
|
|
||||||
// JIT level". The translation logic here ensures that the zero value of
|
|
||||||
// c.GlobalJitLevel means the same as the default value of
|
|
||||||
// OptimizerOptions.global_jit_level in the Python API.
|
|
||||||
enableXLACompilationAsChar := C.uchar(0)
|
|
||||||
switch c.GlobalJitLevel {
|
|
||||||
case JitDefault:
|
|
||||||
// TODO(frreiss): When the semantics of GlobalJitLevel.DEFAULT change to
|
|
||||||
// "on", uncomment the following line.
|
|
||||||
// enableXLACompilationAsChar = C.uchar(1)
|
|
||||||
case JitOn:
|
|
||||||
enableXLACompilationAsChar = C.uchar(1)
|
|
||||||
}
|
|
||||||
gpuMemoryAllowGrowthAsChar := C.uchar(0)
|
|
||||||
if c.AllowGPUMemoryGrowth {
|
|
||||||
gpuMemoryAllowGrowthAsChar = 1
|
|
||||||
}
|
|
||||||
// The C API doesn't currently have a way to say "let the system pick how many
|
|
||||||
// CPUs to use," so detect the number of CPUs here.
|
|
||||||
numCPUDevicesAsUint := C.uint(runtime.NumCPU())
|
|
||||||
if c.NumCPUs > 0 {
|
|
||||||
numCPUDevicesAsUint = C.uint(c.NumCPUs)
|
|
||||||
}
|
|
||||||
buf := C.TF_CreateConfig(enableXLACompilationAsChar, gpuMemoryAllowGrowthAsChar, numCPUDevicesAsUint)
|
|
||||||
defer C.TF_DeleteBuffer(buf)
|
|
||||||
// Copy out of C memory.
|
|
||||||
return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length))
|
|
||||||
}
|
|
||||||
|
|
||||||
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
|
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
|
||||||
// values suitable for C library calls.
|
// values suitable for C library calls.
|
||||||
type cRunArgs struct {
|
type cRunArgs struct {
|
||||||
|
@ -250,15 +250,27 @@ func ExamplePartialRun() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionConfig(t *testing.T) {
|
func TestSessionConfig(t *testing.T) {
|
||||||
// Exercise SessionOptions and Config structs
|
// Exercise SessionOptions.
|
||||||
|
// Arguably, a better API would be for SessionOptions.Config to be the
|
||||||
|
// type generated by the protocol buffer compiler. But for now, the
|
||||||
|
// tensorflow package continues to be independent of protocol buffers
|
||||||
|
// and this test exercises the option since the implementation has a
|
||||||
|
// nuanced conversion to C types.
|
||||||
|
//
|
||||||
|
// Till then, the []byte form of Config here was generated using a toy
|
||||||
|
// tensorflow Python program:
|
||||||
|
/*
|
||||||
|
import tensorflow
|
||||||
|
c = tensorflow.ConfigProto()
|
||||||
|
c.intra_op_parallelism_threads = 1
|
||||||
|
print c.SerializeToString()
|
||||||
|
*/
|
||||||
graph := NewGraph()
|
graph := NewGraph()
|
||||||
c, err := Const(graph, "Const", int32(14))
|
c, err := Const(graph, "Const", int32(14))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
// Use the zero values for Config.GlobalJitLevel and NumCPUs
|
opts := SessionOptions{Config: []byte("(\x01")}
|
||||||
config := Config{AllowGPUMemoryGrowth: true}
|
|
||||||
opts := SessionOptions{Config: config.Bytes()}
|
|
||||||
s, err := NewSession(graph, &opts)
|
s, err := NewSession(graph, &opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
Loading…
Reference in New Issue
Block a user