Add SignatureDefs to Go SavedModel API.
PiperOrigin-RevId: 281984017 Change-Id: Iefdf75ed88f54d97a0a7d210f5a42f3123205bf2
This commit is contained in:
parent
f0323528be
commit
f1a3c8af3c
@ -24,6 +24,8 @@ sh_test(
|
||||
"//tensorflow/c/eager:headers", # Eager C library header
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two", # Testdata for LoadSavedModel
|
||||
] + tf_shared_library_deps(),
|
||||
# TODO: Enable this test again once protos are supported by bazel.
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
@ -30,6 +30,7 @@ from source.
|
||||
sudo apt-get install python swig python-numpy # Linux
|
||||
brew install swig # OS X with homebrew
|
||||
```
|
||||
- [Protocol buffer compiler (protoc) 3.x](https://github.com/google/protobuf/releases/)
|
||||
|
||||
### Build
|
||||
|
||||
@ -74,6 +75,7 @@ from source.
|
||||
4. Build and test:
|
||||
|
||||
```sh
|
||||
go generate github.com/tensorflow/tensorflow/tensorflow/go/op
|
||||
go test github.com/tensorflow/tensorflow/tensorflow/go
|
||||
```
|
||||
|
||||
|
@ -19,6 +19,10 @@ package tensorflow
|
||||
import (
|
||||
"runtime"
|
||||
"unsafe"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
tfpb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework"
|
||||
)
|
||||
|
||||
// #include <stdlib.h>
|
||||
@ -28,8 +32,9 @@ import "C"
|
||||
// SavedModel represents the contents of loaded SavedModel.
|
||||
// TODO(jhseu): Add and document metagraphdef when we pregenerate protobufs.
|
||||
type SavedModel struct {
|
||||
Session *Session
|
||||
Graph *Graph
|
||||
Session *Session
|
||||
Graph *Graph
|
||||
Signatures map[string]Signature
|
||||
}
|
||||
|
||||
// LoadSavedModel creates a new SavedModel from a model previously
|
||||
@ -58,17 +63,35 @@ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*
|
||||
cTags[i] = C.CString(tags[i])
|
||||
}
|
||||
graph := NewGraph()
|
||||
metaGraphDefBuf := C.TF_NewBuffer()
|
||||
defer C.TF_DeleteBuffer(metaGraphDefBuf)
|
||||
// TODO(jhseu): Add support for run_options and meta_graph_def.
|
||||
cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, nil, status.c)
|
||||
cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, metaGraphDefBuf, status.c)
|
||||
for i := range cTags {
|
||||
C.free(unsafe.Pointer(cTags[i]))
|
||||
}
|
||||
C.free(unsafe.Pointer(cExportDir))
|
||||
|
||||
metaGraphDefBytes := C.GoBytes(metaGraphDefBuf.data, C.int(metaGraphDefBuf.length))
|
||||
metaGraphDef := new(tfpb.MetaGraphDef)
|
||||
if err := proto.Unmarshal(metaGraphDefBytes, metaGraphDef); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signatures := generateSignatures(metaGraphDef.GetSignatureDef())
|
||||
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &Session{c: cSess}
|
||||
runtime.SetFinalizer(s, func(s *Session) { s.Close() })
|
||||
return &SavedModel{Session: s, Graph: graph}, nil
|
||||
return &SavedModel{Session: s, Graph: graph, Signatures: signatures}, nil
|
||||
}
|
||||
|
||||
func generateSignatures(pb map[string]*tfpb.SignatureDef) map[string]Signature {
|
||||
signatures := make(map[string]Signature)
|
||||
for name, signature := range pb {
|
||||
signatures[name] = signatureDefFromProto(signature)
|
||||
}
|
||||
return signatures
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ func TestSavedModel(t *testing.T) {
|
||||
if op := bundle.Graph.Operation("y"); op == nil {
|
||||
t.Fatalf("\"y\" not found in graph")
|
||||
}
|
||||
t.Logf("SavedModel: %+v", bundle)
|
||||
// TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
|
||||
// more thorough test when the generated protobufs are available.
|
||||
}
|
||||
|
119
tensorflow/go/signature.go
Normal file
119
tensorflow/go/signature.go
Normal file
@ -0,0 +1,119 @@
|
||||
/*
|
||||
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package tensorflow
|
||||
|
||||
import tfpb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework"
|
||||
|
||||
// #include "tensorflow/c/c_api.h"
|
||||
import "C"
|
||||
|
||||
// A Signature defines the signature of a computation supported by a TensorFlow
|
||||
// graph.
|
||||
//
|
||||
// For example, a model with two loss computations, sharing a single input,
|
||||
// might have the following signature_def map.
|
||||
//
|
||||
// Note that across the two Signatures "loss_A" and "loss_B", the input key,
|
||||
// output key, and method_name are identical, and will be used by system(s) that
|
||||
// implement or rely upon this particular loss method. The output tensor names
|
||||
// differ, demonstrating how different outputs can exist for the same method.
|
||||
//
|
||||
// signature_def {
|
||||
// key: "loss_A"
|
||||
// value {
|
||||
// inputs {
|
||||
// key: "input"
|
||||
// value {
|
||||
// name: "input:0"
|
||||
// dtype: DT_STRING
|
||||
// tensor_shape: ...
|
||||
// }
|
||||
// }
|
||||
// outputs {
|
||||
// key: "loss_output"
|
||||
// value {
|
||||
// name: "loss_output_A:0"
|
||||
// dtype: DT_FLOAT
|
||||
// tensor_shape: ...
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ...
|
||||
// method_name: "some/package/compute_loss"
|
||||
// }
|
||||
// signature_def {
|
||||
// key: "loss_B"
|
||||
// value {
|
||||
// inputs {
|
||||
// key: "input"
|
||||
// value {
|
||||
// name: "input:0"
|
||||
// dtype: DT_STRING
|
||||
// tensor_shape: ...
|
||||
// }
|
||||
// }
|
||||
// outputs {
|
||||
// key: "loss_output"
|
||||
// value {
|
||||
// name: "loss_output_B:0"
|
||||
// dtype: DT_FLOAT
|
||||
// tensor_shape: ...
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ...
|
||||
// method_name: "some/package/compute_loss"
|
||||
// }
|
||||
type Signature struct {
|
||||
Inputs, Outputs map[string]TensorInfo
|
||||
MethodName string
|
||||
}
|
||||
|
||||
// A TensorInfo contains the information about a Tensor necessary for feeding or retrieval.
|
||||
type TensorInfo struct {
|
||||
Name string
|
||||
DType DataType
|
||||
Shape Shape
|
||||
}
|
||||
|
||||
func signatureDefFromProto(pb *tfpb.SignatureDef) Signature {
|
||||
inputs := make(map[string]TensorInfo)
|
||||
for name, input := range pb.GetInputs() {
|
||||
inputs[name] = tensorInfoFromProto(input)
|
||||
}
|
||||
outputs := make(map[string]TensorInfo)
|
||||
for name, output := range pb.GetOutputs() {
|
||||
outputs[name] = tensorInfoFromProto(output)
|
||||
}
|
||||
return Signature{
|
||||
Inputs: inputs,
|
||||
Outputs: outputs,
|
||||
MethodName: pb.GetMethodName(),
|
||||
}
|
||||
}
|
||||
|
||||
func tensorInfoFromProto(pb *tfpb.TensorInfo) TensorInfo {
|
||||
var dims []int64
|
||||
for _, d := range pb.GetTensorShape().GetDim() {
|
||||
dims = append(dims, d.GetSize())
|
||||
}
|
||||
return TensorInfo{
|
||||
Name: pb.GetName(),
|
||||
DType: DataType(C.TF_DataType(pb.GetDtype())),
|
||||
Shape: MakeShape(dims...),
|
||||
}
|
||||
}
|
205
tensorflow/go/signature_test.go
Normal file
205
tensorflow/go/signature_test.go
Normal file
@ -0,0 +1,205 @@
|
||||
/*
|
||||
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
tfpb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework"
|
||||
)
|
||||
|
||||
func TestSignatureFromProto(t *testing.T) {
|
||||
got := signatureDefFromProto(&tfpb.SignatureDef{
|
||||
Inputs: map[string]*tfpb.TensorInfo{
|
||||
"input_1": &tfpb.TensorInfo{
|
||||
Encoding: &tfpb.TensorInfo_Name{
|
||||
Name: "tensor_1",
|
||||
},
|
||||
Dtype: tfpb.DataType_DT_INT8,
|
||||
TensorShape: &tfpb.TensorShapeProto{
|
||||
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||
{Size: 1},
|
||||
{Size: 2},
|
||||
{Size: 3},
|
||||
},
|
||||
},
|
||||
},
|
||||
"input_2": &tfpb.TensorInfo{
|
||||
Encoding: &tfpb.TensorInfo_Name{
|
||||
Name: "tensor_2",
|
||||
},
|
||||
Dtype: tfpb.DataType_DT_FLOAT,
|
||||
TensorShape: &tfpb.TensorShapeProto{
|
||||
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||
{Size: 4},
|
||||
{Size: 5},
|
||||
{Size: 6},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Outputs: map[string]*tfpb.TensorInfo{
|
||||
"output_1": &tfpb.TensorInfo{
|
||||
Encoding: &tfpb.TensorInfo_Name{
|
||||
Name: "tensor_3",
|
||||
},
|
||||
Dtype: tfpb.DataType_DT_STRING,
|
||||
TensorShape: &tfpb.TensorShapeProto{
|
||||
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||
{Size: 1},
|
||||
{Size: 2},
|
||||
{Size: 3},
|
||||
},
|
||||
},
|
||||
},
|
||||
"output_2": &tfpb.TensorInfo{
|
||||
Encoding: &tfpb.TensorInfo_Name{
|
||||
Name: "tensor_4",
|
||||
},
|
||||
Dtype: tfpb.DataType_DT_BOOL,
|
||||
TensorShape: &tfpb.TensorShapeProto{
|
||||
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||
{Size: 4},
|
||||
{Size: 5},
|
||||
{Size: 6},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
MethodName: "method",
|
||||
})
|
||||
|
||||
want := Signature{
|
||||
Inputs: map[string]TensorInfo{
|
||||
"input_1": TensorInfo{
|
||||
Name: "tensor_1",
|
||||
DType: Int8,
|
||||
Shape: MakeShape(1, 2, 3),
|
||||
},
|
||||
"input_2": TensorInfo{
|
||||
Name: "tensor_2",
|
||||
DType: Float,
|
||||
Shape: MakeShape(4, 5, 6),
|
||||
},
|
||||
},
|
||||
Outputs: map[string]TensorInfo{
|
||||
"output_1": TensorInfo{
|
||||
Name: "tensor_3",
|
||||
DType: String,
|
||||
Shape: MakeShape(1, 2, 3),
|
||||
},
|
||||
"output_2": TensorInfo{
|
||||
Name: "tensor_4",
|
||||
DType: Bool,
|
||||
Shape: MakeShape(4, 5, 6),
|
||||
},
|
||||
},
|
||||
MethodName: "method",
|
||||
}
|
||||
|
||||
for k, input := range want.Inputs {
|
||||
diff, err := diffTensorInfos(got.Inputs[k], input)
|
||||
if err != nil {
|
||||
t.Fatalf("Signature.Inputs[%s]: unable to diff TensorInfos: %v", k, err)
|
||||
}
|
||||
if diff != "" {
|
||||
t.Errorf("Signature.Inputs[%s] diff:\n%s", k, diff)
|
||||
}
|
||||
}
|
||||
|
||||
for k, output := range want.Outputs {
|
||||
diff, err := diffTensorInfos(got.Outputs[k], output)
|
||||
if err != nil {
|
||||
t.Fatalf("Signature.Outputs[%s]: unable to diff TensorInfos: %v", k, err)
|
||||
}
|
||||
if diff != "" {
|
||||
t.Errorf("Signature.Outputs[%s] diff:\n%s", k, diff)
|
||||
}
|
||||
}
|
||||
|
||||
if got.MethodName != want.MethodName {
|
||||
t.Errorf("Signature.MethodName: got %q, want %q", got.MethodName, want.MethodName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorInfoFromProto(t *testing.T) {
|
||||
got := tensorInfoFromProto(&tfpb.TensorInfo{
|
||||
Encoding: &tfpb.TensorInfo_Name{
|
||||
Name: "tensor",
|
||||
},
|
||||
Dtype: tfpb.DataType_DT_INT8,
|
||||
TensorShape: &tfpb.TensorShapeProto{
|
||||
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||
{Size: 1},
|
||||
{Size: 2},
|
||||
{Size: 3},
|
||||
},
|
||||
},
|
||||
})
|
||||
want := TensorInfo{
|
||||
Name: "tensor",
|
||||
DType: Int8,
|
||||
Shape: MakeShape(1, 2, 3),
|
||||
}
|
||||
|
||||
diff, err := diffTensorInfos(got, want)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to diff TensorInfos: %v", err)
|
||||
}
|
||||
if diff != "" {
|
||||
t.Errorf("tensorInfoFromProto produced a diff (got -> want): %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func diffTensorInfos(a, b TensorInfo) (string, error) {
|
||||
diff := ""
|
||||
if a.Name != b.Name {
|
||||
diff += fmt.Sprintf("Name: %q -> %q\n", a.Name, b.Name)
|
||||
}
|
||||
if a.DType != b.DType {
|
||||
diff += fmt.Sprintf("DType: %v -> %v\n", a.DType, b.DType)
|
||||
}
|
||||
|
||||
aShape, err := a.Shape.ToSlice()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
bShape, err := b.Shape.ToSlice()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
shapeLen := len(aShape)
|
||||
if len(bShape) > shapeLen {
|
||||
shapeLen = len(bShape)
|
||||
}
|
||||
for i := 0; i < shapeLen; i++ {
|
||||
if i >= len(aShape) {
|
||||
diff += fmt.Sprintf("+Shape[%d]: %d\n", i, bShape[i])
|
||||
continue
|
||||
}
|
||||
if i >= len(bShape) {
|
||||
diff += fmt.Sprintf("-Shape[%d]: %d\n", i, aShape[i])
|
||||
continue
|
||||
}
|
||||
if aShape[i] != bShape[i] {
|
||||
diff += fmt.Sprintf("Shape[%d]: %d -> %d\n", i, aShape[i], bShape[i])
|
||||
}
|
||||
}
|
||||
|
||||
return diff, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user