Automated rollback of commit 0bc46a1602
. Revert #26682.
PiperOrigin-RevId: 239098662
This commit is contained in:
parent
dc0137f16a
commit
6e9cb400d1
@ -18,7 +18,6 @@ package tensorflow
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "tensorflow/c/c_api.h"
|
||||
// #include "tensorflow/c/c_api_experimental.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
@ -316,11 +315,6 @@ type SessionOptions struct {
|
||||
// Config is a binary-serialized representation of the
|
||||
// tensorflow.ConfigProto protocol message
|
||||
// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
|
||||
// You can populate this field in three ways. You can use the zero value
|
||||
// of this field, which configures the session with a default set of
|
||||
// options. Or you create a `Config` struct with your options and call that
|
||||
// struct's `Bytes()` method. Or you can generate a byte string outside of
|
||||
// Go and paste that string into your Go program as a string literal.
|
||||
Config []byte
|
||||
}
|
||||
|
||||
@ -355,81 +349,6 @@ func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// JitLevel represents the level of optimization that the XLA compiler
|
||||
// performs during just-in-time compilation.
|
||||
type JitLevel int
|
||||
|
||||
const (
|
||||
// JitDefault is the default setting for this version of TensorFlow
|
||||
// and corresponds to the "DEFAULT" level in the Python API.
|
||||
// Currently the default is `JitOff`, but it will change to `JitOn` in a
|
||||
// future version.
|
||||
JitDefault JitLevel = 0
|
||||
// JitOff disables just-in-time compilation and will continue to do so
|
||||
// even after JIT compilation is enabled by default.
|
||||
JitOff JitLevel = -1
|
||||
// JitOn enables just-in-time compilation. It is a synonym for the
|
||||
// "ON_1" optimization level in the Python API.
|
||||
JitOn JitLevel = 1
|
||||
)
|
||||
|
||||
// Config represents session parameters as encoded in the tensorflow.ConfigProto
|
||||
// protocol buffer message.
|
||||
type Config struct {
|
||||
// GlobalJitLevel controls the degree of optimization that the XLA just-in-time
|
||||
// compiler will perform. The default is currently "off", but it is expected
|
||||
// to change to "on" in a future version of TensorFlow.
|
||||
GlobalJitLevel JitLevel
|
||||
|
||||
// AllowGPUMemoryGrowth controls whether the TensorFlow memory allocator
|
||||
// pre-allocates the entire specified GPU memory region or instead starts
|
||||
// with a small block of GPU memory and grows its memory usage as needed.
|
||||
AllowGPUMemoryGrowth bool
|
||||
|
||||
// NumCPUs is the maximum number of CPU devices that the session will use.
|
||||
// A value of 0 means "let the system pick an appropriate number"
|
||||
NumCPUs int
|
||||
|
||||
// This struct only exposes the session options available via TensorFlow's
|
||||
// experimental C API function `TF_CreateConfig()`.
|
||||
// TODO(frreiss): Add additional options here as more session options are
|
||||
// exposed via the C API.
|
||||
}
|
||||
|
||||
// Bytes generates a serialized ConfigOptions protobuf for use in the `Config`
|
||||
// field of a `SessionOptions` struct.
|
||||
func (c *Config) Bytes() []byte {
|
||||
// The C API expects an unsigned char that is 0 if XLA compilation is off and
|
||||
// nonzero otherwise.
|
||||
// There is currently no way in the C API to specify "use TensorFlow's default
|
||||
// JIT level". The translation logic here ensures that the zero value of
|
||||
// c.GlobalJitLevel means the same as the default value of
|
||||
// OptimizerOptions.global_jit_level in the Python API.
|
||||
enableXLACompilationAsChar := C.uchar(0)
|
||||
switch c.GlobalJitLevel {
|
||||
case JitDefault:
|
||||
// TODO(frreiss): When the semantics of GlobalJitLevel.DEFAULT change to
|
||||
// "on", uncomment the following line.
|
||||
// enableXLACompilationAsChar = C.uchar(1)
|
||||
case JitOn:
|
||||
enableXLACompilationAsChar = C.uchar(1)
|
||||
}
|
||||
gpuMemoryAllowGrowthAsChar := C.uchar(0)
|
||||
if c.AllowGPUMemoryGrowth {
|
||||
gpuMemoryAllowGrowthAsChar = 1
|
||||
}
|
||||
// The C API doesn't currently have a way to say "let the system pick how many
|
||||
// CPUs to use," so detect the number of CPUs here.
|
||||
numCPUDevicesAsUint := C.uint(runtime.NumCPU())
|
||||
if c.NumCPUs > 0 {
|
||||
numCPUDevicesAsUint = C.uint(c.NumCPUs)
|
||||
}
|
||||
buf := C.TF_CreateConfig(enableXLACompilationAsChar, gpuMemoryAllowGrowthAsChar, numCPUDevicesAsUint)
|
||||
defer C.TF_DeleteBuffer(buf)
|
||||
// Copy out of C memory.
|
||||
return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length))
|
||||
}
|
||||
|
||||
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
|
||||
// values suitable for C library calls.
|
||||
type cRunArgs struct {
|
||||
|
@ -250,15 +250,27 @@ func ExamplePartialRun() {
|
||||
}
|
||||
|
||||
func TestSessionConfig(t *testing.T) {
|
||||
// Exercise SessionOptions and Config structs
|
||||
// Exercise SessionOptions.
|
||||
// Arguably, a better API would be for SessionOptions.Config to be the
|
||||
// type generated by the protocol buffer compiler. But for now, the
|
||||
// tensorflow package continues to be independent of protocol buffers
|
||||
// and this test exercises the option since the implementation has a
|
||||
// nuanced conversion to C types.
|
||||
//
|
||||
// Till then, the []byte form of Config here was generated using a toy
|
||||
// tensorflow Python program:
|
||||
/*
|
||||
import tensorflow
|
||||
c = tensorflow.ConfigProto()
|
||||
c.intra_op_parallelism_threads = 1
|
||||
print c.SerializeToString()
|
||||
*/
|
||||
graph := NewGraph()
|
||||
c, err := Const(graph, "Const", int32(14))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Use the zero values for Config.GlobalJitLevel and NumCPUs
|
||||
config := Config{AllowGPUMemoryGrowth: true}
|
||||
opts := SessionOptions{Config: config.Bytes()}
|
||||
opts := SessionOptions{Config: []byte("(\x01")}
|
||||
s, err := NewSession(graph, &opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
Loading…
Reference in New Issue
Block a user