SavedModel support in Go.
Change: 146938337
This commit is contained in:
parent
42dc6764a0
commit
085102c2e2
@ -59,6 +59,44 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// LoadSavedModel creates a new Session from a model previously exported to a
|
||||
// directory on disk.
|
||||
//
|
||||
// Exported models contain a set of graphs and variable values. Tags in the
|
||||
// model identify a single graph. LoadSessionFromSavedModel initializes a
|
||||
// session with the identified graph and with variables initialized to saved
|
||||
// values.
|
||||
//
|
||||
// The tensorflow package currently does not have the ability to export a model
|
||||
// to a directory from Go. This function thus currently targets loading models
|
||||
// exported in other languages, such as using tf.saved_model.builder in Python.
|
||||
// See:
|
||||
// https://www.tensorflow.org/code/tensorflow/python/saved_model/
|
||||
func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*Session, *Graph, error) {
|
||||
status := newStatus()
|
||||
cOpt := options.c()
|
||||
cExportDir := C.CString(exportDir)
|
||||
cTags := make([]*C.char, len(tags))
|
||||
for i := range tags {
|
||||
cTags[i] = C.CString(tags[i])
|
||||
}
|
||||
graph := NewGraph()
|
||||
// TODO(jhseu): Add support for run_options and meta_graph_def.
|
||||
cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, nil, status.c)
|
||||
for i := range cTags {
|
||||
C.free(unsafe.Pointer(cTags[i]))
|
||||
}
|
||||
C.free(unsafe.Pointer(cExportDir))
|
||||
C.TF_DeleteSessionOptions(cOpt)
|
||||
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
s := &Session{c: cSess}
|
||||
runtime.SetFinalizer(s, func(s *Session) { s.Close() })
|
||||
return s, graph, nil
|
||||
}
|
||||
|
||||
// Run the graph with the associated session starting with the supplied inputs.
|
||||
// inputs and outputs may be set to nil. Runs, but does not return Tensors
|
||||
// for operations specified in targets.
|
||||
|
@ -181,3 +181,15 @@ func TestConcurrency(t *testing.T) {
|
||||
t.Errorf("Close() 2: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSavedModel(t *testing.T) {
|
||||
_, graph, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", []string{"serve"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadSavedModel(): %v", err)
|
||||
}
|
||||
if op := graph.Operation("y"); op == nil {
|
||||
t.Fatalf("\"y\" not found in graph")
|
||||
}
|
||||
// TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
|
||||
// more thorough test when the generated protobufs are available.
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user