From 085102c2e2947d76056b6363da96c55ecd838e6c Mon Sep 17 00:00:00 2001 From: Jonathan Hseu Date: Wed, 8 Feb 2017 11:55:21 -0800 Subject: [PATCH] SavedModel support in Go. Change: 146938337 --- tensorflow/go/session.go | 38 +++++++++++++++++++++++++++++++++++ tensorflow/go/session_test.go | 12 +++++++++++ 2 files changed, 50 insertions(+) diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index dd629441efa..c29b6e0b769 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -59,6 +59,44 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) { return s, nil } +// LoadSavedModel creates a new Session from a model previously exported to a +// directory on disk. +// +// Exported models contain a set of graphs and variable values. Tags in the +// model identify a single graph. LoadSessionFromSavedModel initializes a +// session with the identified graph and with variables initialized to saved +// values. +// +// The tensorflow package currently does not have the ability to export a model +// to a directory from Go. This function thus currently targets loading models +// exported in other languages, such as using tf.saved_model.builder in Python. +// See: +// https://www.tensorflow.org/code/tensorflow/python/saved_model/ +func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*Session, *Graph, error) { + status := newStatus() + cOpt := options.c() + cExportDir := C.CString(exportDir) + cTags := make([]*C.char, len(tags)) + for i := range tags { + cTags[i] = C.CString(tags[i]) + } + graph := NewGraph() + // TODO(jhseu): Add support for run_options and meta_graph_def. + cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, nil, status.c) + for i := range cTags { + C.free(unsafe.Pointer(cTags[i])) + } + C.free(unsafe.Pointer(cExportDir)) + C.TF_DeleteSessionOptions(cOpt) + + if err := status.Err(); err != nil { + return nil, nil, err + } + s := &Session{c: cSess} + runtime.SetFinalizer(s, func(s *Session) { s.Close() }) + return s, graph, nil +} + // Run the graph with the associated session starting with the supplied inputs. // inputs and outputs may be set to nil. Runs, but does not return Tensors // for operations specified in targets. diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index 14ecca402b2..ccd7d85295b 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -181,3 +181,15 @@ func TestConcurrency(t *testing.T) { t.Errorf("Close() 2: %v", err) } } + +func TestSavedModel(t *testing.T) { + _, graph, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", []string{"serve"}, nil) + if err != nil { + t.Fatalf("LoadSavedModel(): %v", err) + } + if op := graph.Operation("y"); op == nil { + t.Fatalf("\"y\" not found in graph") + } + // TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a + // more thorough test when the generated protobufs are available. +}