Go: Provide a mechanism to configure the Session.

A Session is configured using the ConfigProto protocol buffer.
For now, continuing with attempts to keep the 'tensorflow' go package
free of any protocol buffer dependencies, SessionOptions uses a serialized
representation of this message. This choice might make sense to revisit.
Change: 148750535
This commit is contained in:
Asim Shankar 2017-02-28 03:01:11 -08:00 committed by TensorFlower Gardener
parent 5a31e9c8bd
commit 49a4ebbf3c
3 changed files with 73 additions and 8 deletions

View File

@ -45,7 +45,11 @@ type SavedModel struct {
// https://www.tensorflow.org/code/tensorflow/python/saved_model/
func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error) {
status := newStatus()
cOpt := options.c()
cOpt, doneOpt, err := options.c()
defer doneOpt()
if err != nil {
return nil, err
}
cExportDir := C.CString(exportDir)
cTags := make([]*C.char, len(tags))
for i := range tags {
@ -58,7 +62,6 @@ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*
C.free(unsafe.Pointer(cTags[i]))
}
C.free(unsafe.Pointer(cExportDir))
C.TF_DeleteSessionOptions(cOpt)
if err := status.Err(); err != nil {
return nil, err

View File

@ -20,6 +20,7 @@ import "C"
import (
"errors"
"fmt"
"runtime"
"sync"
"unsafe"
@ -47,9 +48,12 @@ type Session struct {
// options may be nil to use the default options.
func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
status := newStatus()
cOpt := options.c()
cOpt, doneOpt, err := options.c()
defer doneOpt()
if err != nil {
return nil, err
}
cSess := C.TF_NewSession(graph.c, cOpt, status.c)
C.TF_DeleteSessionOptions(cOpt)
if err := status.Err(); err != nil {
return nil, err
}
@ -243,19 +247,42 @@ type SessionOptions struct {
// If the session disconnects from the remote process during its
// lifetime, session calls may fail immediately.
Target string
// Config is a binary-serialized representation of the
// tensorflow.ConfigProto protocol message
// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
Config []byte
}
// c converts the SessionOptions to the C API's TF_SessionOptions. Callers must
// deallocate by calling C.TF_DeleteSessionOptions().
func (o *SessionOptions) c() *C.TF_SessionOptions {
// deallocate by calling the returned done() closure.
func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) {
opt := C.TF_NewSessionOptions()
if o == nil {
return opt
return opt, func() { C.TF_DeleteSessionOptions(opt) }, nil
}
t := C.CString(o.Target)
C.TF_SetTarget(opt, t)
C.free(unsafe.Pointer(t))
return opt
var cConfig unsafe.Pointer
if sz := len(o.Config); sz > 0 {
status := newStatus()
// Copying into C-memory is the simplest thing to do in terms
// of memory safety and cgo rules ("C code may not keep a copy
// of a Go pointer after the call returns" from
// https://golang.org/cmd/cgo/#hdr-Passing_pointers).
cConfig = C.CBytes(o.Config)
C.TF_SetConfig(opt, cConfig, C.size_t(sz), status.c)
if err := status.Err(); err != nil {
C.TF_DeleteSessionOptions(opt)
return nil, func() {}, fmt.Errorf("invalid SessionOptions.Config: %v", err)
}
}
return opt, func() {
C.TF_DeleteSessionOptions(opt)
C.free(cConfig)
}, nil
}
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into

View File

@ -246,3 +246,38 @@ func ExamplePartialRun() {
fmt.Println(v1, v2)
// Output: 3 10
}
func TestSessionConfig(t *testing.T) {
// Exercise SessionOptions.
// Arguably, a better API would be for SessionOptions.Config to be the
// type generated by the protocol buffer compiler. But for now, the
// tensorflow package continues to be independent of protocol buffers
// and this test exercises the option since the implementation has a
// nuanced conversion to C types.
//
// Till then, the []byte form of Config here was generated using a toy
// tensorflow Python program:
/*
import tensorflow
c = tensorflow.ConfigProto()
c.intra_op_parallelism_threads = 1
print c.SerializeToString()
*/
graph := NewGraph()
c, err := Const(graph, "Const", int32(14))
if err != nil {
t.Fatal(err)
}
opts := SessionOptions{Config: []byte("(\x01")}
s, err := NewSession(graph, &opts)
if err != nil {
t.Fatal(err)
}
output, err := s.Run(nil, []Output{c}, nil)
if err != nil {
t.Fatal(err)
}
if output[0].Value().(int32) != 14 {
t.Fatalf("Got %v, want -1", output[0].Value())
}
}