101 lines
3.3 KiB
Go
101 lines
3.3 KiB
Go
/*
|
|
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package tensorflow
|
|
|
|
import (
|
|
"fmt"
|
|
"runtime"
|
|
"unsafe"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
|
|
)
|
|
|
|
// #include <stdlib.h>
|
|
// #include "tensorflow/c/c_api.h"
|
|
import "C"
|
|
|
|
// SavedModel represents the contents of loaded SavedModel.
|
|
// TODO(jhseu): Add and document metagraphdef when we pregenerate protobufs.
|
|
type SavedModel struct {
|
|
Session *Session
|
|
Graph *Graph
|
|
Signatures map[string]Signature
|
|
}
|
|
|
|
// LoadSavedModel creates a new SavedModel from a model previously
|
|
// exported to a directory on disk.
|
|
//
|
|
// Exported models contain a set of graphs and, optionally, variable values.
|
|
// Tags in the model identify a single graph. LoadSavedModel initializes a
|
|
// session with the identified graph and with variables initialized to from the
|
|
// checkpoints on disk.
|
|
//
|
|
// The tensorflow package currently does not have the ability to export a model
|
|
// to a directory from Go. This function thus currently targets loading models
|
|
// exported in other languages, such as using tf.saved_model.builder in Python.
|
|
// See:
|
|
// https://www.tensorflow.org/code/tensorflow/python/saved_model/
|
|
func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error) {
|
|
status := newStatus()
|
|
cOpt, doneOpt, err := options.c()
|
|
defer doneOpt()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cExportDir := C.CString(exportDir)
|
|
if len(tags) == 0 {
|
|
return nil, fmt.Errorf("empty tags are not allowed")
|
|
}
|
|
cTags := make([]*C.char, len(tags))
|
|
for i := range tags {
|
|
cTags[i] = C.CString(tags[i])
|
|
}
|
|
graph := NewGraph()
|
|
metaGraphDefBuf := C.TF_NewBuffer()
|
|
defer C.TF_DeleteBuffer(metaGraphDefBuf)
|
|
// TODO(jhseu): Add support for run_options and meta_graph_def.
|
|
cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, metaGraphDefBuf, status.c)
|
|
for i := range cTags {
|
|
C.free(unsafe.Pointer(cTags[i]))
|
|
}
|
|
C.free(unsafe.Pointer(cExportDir))
|
|
|
|
metaGraphDefBytes := C.GoBytes(metaGraphDefBuf.data, C.int(metaGraphDefBuf.length))
|
|
metaGraphDef := new(corepb.MetaGraphDef)
|
|
if err := proto.Unmarshal(metaGraphDefBytes, metaGraphDef); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
signatures := generateSignatures(metaGraphDef.GetSignatureDef())
|
|
|
|
if err := status.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
s := &Session{c: cSess}
|
|
runtime.SetFinalizer(s, func(s *Session) { s.Close() })
|
|
return &SavedModel{Session: s, Graph: graph, Signatures: signatures}, nil
|
|
}
|
|
|
|
func generateSignatures(pb map[string]*corepb.SignatureDef) map[string]Signature {
|
|
signatures := make(map[string]Signature)
|
|
for name, signature := range pb {
|
|
signatures[name] = signatureDefFromProto(signature)
|
|
}
|
|
return signatures
|
|
}
|