SavedModel support in Go.

Change: 146938337
This commit is contained in:
Jonathan Hseu 2017-02-08 11:55:21 -08:00 committed by TensorFlower Gardener
parent 42dc6764a0
commit 085102c2e2
2 changed files with 50 additions and 0 deletions

View File

@ -59,6 +59,44 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
return s, nil
}
// LoadSavedModel creates a new Session from a model previously exported to a
// directory on disk.
//
// Exported models contain a set of graphs and variable values. Tags in the
// model identify a single graph. LoadSessionFromSavedModel initializes a
// session with the identified graph and with variables initialized to saved
// values.
//
// The tensorflow package currently does not have the ability to export a model
// to a directory from Go. This function thus currently targets loading models
// exported in other languages, such as using tf.saved_model.builder in Python.
// See:
// https://www.tensorflow.org/code/tensorflow/python/saved_model/
func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*Session, *Graph, error) {
status := newStatus()
cOpt := options.c()
cExportDir := C.CString(exportDir)
cTags := make([]*C.char, len(tags))
for i := range tags {
cTags[i] = C.CString(tags[i])
}
graph := NewGraph()
// TODO(jhseu): Add support for run_options and meta_graph_def.
cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, nil, status.c)
for i := range cTags {
C.free(unsafe.Pointer(cTags[i]))
}
C.free(unsafe.Pointer(cExportDir))
C.TF_DeleteSessionOptions(cOpt)
if err := status.Err(); err != nil {
return nil, nil, err
}
s := &Session{c: cSess}
runtime.SetFinalizer(s, func(s *Session) { s.Close() })
return s, graph, nil
}
// Run the graph with the associated session starting with the supplied inputs.
// inputs and outputs may be set to nil. Runs, but does not return Tensors
// for operations specified in targets.

View File

@ -181,3 +181,15 @@ func TestConcurrency(t *testing.T) {
t.Errorf("Close() 2: %v", err)
}
}
func TestSavedModel(t *testing.T) {
_, graph, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", []string{"serve"}, nil)
if err != nil {
t.Fatalf("LoadSavedModel(): %v", err)
}
if op := graph.Operation("y"); op == nil {
t.Fatalf("\"y\" not found in graph")
}
// TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
// more thorough test when the generated protobufs are available.
}