From 49a4ebbf3cb307c513653427b32f30ad35855094 Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Tue, 28 Feb 2017 03:01:11 -0800 Subject: [PATCH] Go: Provide a mechanism to configure the Session. A Session is configured using the ConfigProto protocol buffer. For now, continuing with attempts to keep the 'tensorflow' go package free of any protocol buffer dependencies, SessionOptions uses a serialized representation of this message. This choice might make sense to revisit. Change: 148750535 --- tensorflow/go/saved_model.go | 7 +++++-- tensorflow/go/session.go | 39 +++++++++++++++++++++++++++++------ tensorflow/go/session_test.go | 35 +++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/tensorflow/go/saved_model.go b/tensorflow/go/saved_model.go index dd8ad307b96..32e40d9a952 100644 --- a/tensorflow/go/saved_model.go +++ b/tensorflow/go/saved_model.go @@ -45,7 +45,11 @@ type SavedModel struct { // https://www.tensorflow.org/code/tensorflow/python/saved_model/ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error) { status := newStatus() - cOpt := options.c() + cOpt, doneOpt, err := options.c() + defer doneOpt() + if err != nil { + return nil, err + } cExportDir := C.CString(exportDir) cTags := make([]*C.char, len(tags)) for i := range tags { @@ -58,7 +62,6 @@ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (* C.free(unsafe.Pointer(cTags[i])) } C.free(unsafe.Pointer(cExportDir)) - C.TF_DeleteSessionOptions(cOpt) if err := status.Err(); err != nil { return nil, err diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index ef357cb5205..5a6e1e37ad3 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -20,6 +20,7 @@ import "C" import ( "errors" + "fmt" "runtime" "sync" "unsafe" @@ -47,9 +48,12 @@ type Session struct { // options may be nil to use the default options. func NewSession(graph *Graph, options *SessionOptions) (*Session, error) { status := newStatus() - cOpt := options.c() + cOpt, doneOpt, err := options.c() + defer doneOpt() + if err != nil { + return nil, err + } cSess := C.TF_NewSession(graph.c, cOpt, status.c) - C.TF_DeleteSessionOptions(cOpt) if err := status.Err(); err != nil { return nil, err } @@ -243,19 +247,42 @@ type SessionOptions struct { // If the session disconnects from the remote process during its // lifetime, session calls may fail immediately. Target string + + // Config is a binary-serialized representation of the + // tensorflow.ConfigProto protocol message + // (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto). + Config []byte } // c converts the SessionOptions to the C API's TF_SessionOptions. Callers must -// deallocate by calling C.TF_DeleteSessionOptions(). -func (o *SessionOptions) c() *C.TF_SessionOptions { +// deallocate by calling the returned done() closure. +func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) { opt := C.TF_NewSessionOptions() if o == nil { - return opt + return opt, func() { C.TF_DeleteSessionOptions(opt) }, nil } t := C.CString(o.Target) C.TF_SetTarget(opt, t) C.free(unsafe.Pointer(t)) - return opt + + var cConfig unsafe.Pointer + if sz := len(o.Config); sz > 0 { + status := newStatus() + // Copying into C-memory is the simplest thing to do in terms + // of memory safety and cgo rules ("C code may not keep a copy + // of a Go pointer after the call returns" from + // https://golang.org/cmd/cgo/#hdr-Passing_pointers). + cConfig = C.CBytes(o.Config) + C.TF_SetConfig(opt, cConfig, C.size_t(sz), status.c) + if err := status.Err(); err != nil { + C.TF_DeleteSessionOptions(opt) + return nil, func() {}, fmt.Errorf("invalid SessionOptions.Config: %v", err) + } + } + return opt, func() { + C.TF_DeleteSessionOptions(opt) + C.free(cConfig) + }, nil } // cRunArgs translates the arguments to Session.Run and PartialRun.Run into diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index 9afa2be3b4f..4c1b862e1f7 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -246,3 +246,38 @@ func ExamplePartialRun() { fmt.Println(v1, v2) // Output: 3 10 } + +func TestSessionConfig(t *testing.T) { + // Exercise SessionOptions. + // Arguably, a better API would be for SessionOptions.Config to be the + // type generated by the protocol buffer compiler. But for now, the + // tensorflow package continues to be independent of protocol buffers + // and this test exercises the option since the implementation has a + // nuanced conversion to C types. + // + // Till then, the []byte form of Config here was generated using a toy + // tensorflow Python program: + /* + import tensorflow + c = tensorflow.ConfigProto() + c.intra_op_parallelism_threads = 1 + print c.SerializeToString() + */ + graph := NewGraph() + c, err := Const(graph, "Const", int32(14)) + if err != nil { + t.Fatal(err) + } + opts := SessionOptions{Config: []byte("(\x01")} + s, err := NewSession(graph, &opts) + if err != nil { + t.Fatal(err) + } + output, err := s.Run(nil, []Output{c}, nil) + if err != nil { + t.Fatal(err) + } + if output[0].Value().(int32) != 14 { + t.Fatalf("Got %v, want -1", output[0].Value()) + } +}