Go: Provide a mechanism to configure the Session.
A Session is configured using the ConfigProto protocol buffer. For now, continuing with attempts to keep the 'tensorflow' go package free of any protocol buffer dependencies, SessionOptions uses a serialized representation of this message. This choice might make sense to revisit. Change: 148750535
This commit is contained in:
parent
5a31e9c8bd
commit
49a4ebbf3c
@ -45,7 +45,11 @@ type SavedModel struct {
|
||||
// https://www.tensorflow.org/code/tensorflow/python/saved_model/
|
||||
func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error) {
|
||||
status := newStatus()
|
||||
cOpt := options.c()
|
||||
cOpt, doneOpt, err := options.c()
|
||||
defer doneOpt()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cExportDir := C.CString(exportDir)
|
||||
cTags := make([]*C.char, len(tags))
|
||||
for i := range tags {
|
||||
@ -58,7 +62,6 @@ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*
|
||||
C.free(unsafe.Pointer(cTags[i]))
|
||||
}
|
||||
C.free(unsafe.Pointer(cExportDir))
|
||||
C.TF_DeleteSessionOptions(cOpt)
|
||||
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
|
@ -20,6 +20,7 @@ import "C"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"unsafe"
|
||||
@ -47,9 +48,12 @@ type Session struct {
|
||||
// options may be nil to use the default options.
|
||||
func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
|
||||
status := newStatus()
|
||||
cOpt := options.c()
|
||||
cOpt, doneOpt, err := options.c()
|
||||
defer doneOpt()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cSess := C.TF_NewSession(graph.c, cOpt, status.c)
|
||||
C.TF_DeleteSessionOptions(cOpt)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -243,19 +247,42 @@ type SessionOptions struct {
|
||||
// If the session disconnects from the remote process during its
|
||||
// lifetime, session calls may fail immediately.
|
||||
Target string
|
||||
|
||||
// Config is a binary-serialized representation of the
|
||||
// tensorflow.ConfigProto protocol message
|
||||
// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
|
||||
Config []byte
|
||||
}
|
||||
|
||||
// c converts the SessionOptions to the C API's TF_SessionOptions. Callers must
|
||||
// deallocate by calling C.TF_DeleteSessionOptions().
|
||||
func (o *SessionOptions) c() *C.TF_SessionOptions {
|
||||
// deallocate by calling the returned done() closure.
|
||||
func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) {
|
||||
opt := C.TF_NewSessionOptions()
|
||||
if o == nil {
|
||||
return opt
|
||||
return opt, func() { C.TF_DeleteSessionOptions(opt) }, nil
|
||||
}
|
||||
t := C.CString(o.Target)
|
||||
C.TF_SetTarget(opt, t)
|
||||
C.free(unsafe.Pointer(t))
|
||||
return opt
|
||||
|
||||
var cConfig unsafe.Pointer
|
||||
if sz := len(o.Config); sz > 0 {
|
||||
status := newStatus()
|
||||
// Copying into C-memory is the simplest thing to do in terms
|
||||
// of memory safety and cgo rules ("C code may not keep a copy
|
||||
// of a Go pointer after the call returns" from
|
||||
// https://golang.org/cmd/cgo/#hdr-Passing_pointers).
|
||||
cConfig = C.CBytes(o.Config)
|
||||
C.TF_SetConfig(opt, cConfig, C.size_t(sz), status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
C.TF_DeleteSessionOptions(opt)
|
||||
return nil, func() {}, fmt.Errorf("invalid SessionOptions.Config: %v", err)
|
||||
}
|
||||
}
|
||||
return opt, func() {
|
||||
C.TF_DeleteSessionOptions(opt)
|
||||
C.free(cConfig)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
|
||||
|
@ -246,3 +246,38 @@ func ExamplePartialRun() {
|
||||
fmt.Println(v1, v2)
|
||||
// Output: 3 10
|
||||
}
|
||||
|
||||
func TestSessionConfig(t *testing.T) {
|
||||
// Exercise SessionOptions.
|
||||
// Arguably, a better API would be for SessionOptions.Config to be the
|
||||
// type generated by the protocol buffer compiler. But for now, the
|
||||
// tensorflow package continues to be independent of protocol buffers
|
||||
// and this test exercises the option since the implementation has a
|
||||
// nuanced conversion to C types.
|
||||
//
|
||||
// Till then, the []byte form of Config here was generated using a toy
|
||||
// tensorflow Python program:
|
||||
/*
|
||||
import tensorflow
|
||||
c = tensorflow.ConfigProto()
|
||||
c.intra_op_parallelism_threads = 1
|
||||
print c.SerializeToString()
|
||||
*/
|
||||
graph := NewGraph()
|
||||
c, err := Const(graph, "Const", int32(14))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
opts := SessionOptions{Config: []byte("(\x01")}
|
||||
s, err := NewSession(graph, &opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
output, err := s.Run(nil, []Output{c}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if output[0].Value().(int32) != 14 {
|
||||
t.Fatalf("Got %v, want -1", output[0].Value())
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user