Add SignatureDefs to Go SavedModel API.
PiperOrigin-RevId: 281984017 Change-Id: Iefdf75ed88f54d97a0a7d210f5a42f3123205bf2
This commit is contained in:
parent
f0323528be
commit
f1a3c8af3c
@ -24,6 +24,8 @@ sh_test(
|
|||||||
"//tensorflow/c/eager:headers", # Eager C library header
|
"//tensorflow/c/eager:headers", # Eager C library header
|
||||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two", # Testdata for LoadSavedModel
|
"//tensorflow/cc/saved_model:saved_model_half_plus_two", # Testdata for LoadSavedModel
|
||||||
] + tf_shared_library_deps(),
|
] + tf_shared_library_deps(),
|
||||||
|
# TODO: Enable this test again once protos are supported by bazel.
|
||||||
|
tags = ["manual"],
|
||||||
)
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
|
@ -30,6 +30,7 @@ from source.
|
|||||||
sudo apt-get install python swig python-numpy # Linux
|
sudo apt-get install python swig python-numpy # Linux
|
||||||
brew install swig # OS X with homebrew
|
brew install swig # OS X with homebrew
|
||||||
```
|
```
|
||||||
|
- [Protocol buffer compiler (protoc) 3.x](https://github.com/google/protobuf/releases/)
|
||||||
|
|
||||||
### Build
|
### Build
|
||||||
|
|
||||||
@ -74,6 +75,7 @@ from source.
|
|||||||
4. Build and test:
|
4. Build and test:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
|
go generate github.com/tensorflow/tensorflow/tensorflow/go/op
|
||||||
go test github.com/tensorflow/tensorflow/tensorflow/go
|
go test github.com/tensorflow/tensorflow/tensorflow/go
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -19,6 +19,10 @@ package tensorflow
|
|||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
|
|
||||||
|
tfpb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework"
|
||||||
)
|
)
|
||||||
|
|
||||||
// #include <stdlib.h>
|
// #include <stdlib.h>
|
||||||
@ -28,8 +32,9 @@ import "C"
|
|||||||
// SavedModel represents the contents of loaded SavedModel.
|
// SavedModel represents the contents of loaded SavedModel.
|
||||||
// TODO(jhseu): Add and document metagraphdef when we pregenerate protobufs.
|
// TODO(jhseu): Add and document metagraphdef when we pregenerate protobufs.
|
||||||
type SavedModel struct {
|
type SavedModel struct {
|
||||||
Session *Session
|
Session *Session
|
||||||
Graph *Graph
|
Graph *Graph
|
||||||
|
Signatures map[string]Signature
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSavedModel creates a new SavedModel from a model previously
|
// LoadSavedModel creates a new SavedModel from a model previously
|
||||||
@ -58,17 +63,35 @@ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*
|
|||||||
cTags[i] = C.CString(tags[i])
|
cTags[i] = C.CString(tags[i])
|
||||||
}
|
}
|
||||||
graph := NewGraph()
|
graph := NewGraph()
|
||||||
|
metaGraphDefBuf := C.TF_NewBuffer()
|
||||||
|
defer C.TF_DeleteBuffer(metaGraphDefBuf)
|
||||||
// TODO(jhseu): Add support for run_options and meta_graph_def.
|
// TODO(jhseu): Add support for run_options and meta_graph_def.
|
||||||
cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, nil, status.c)
|
cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, metaGraphDefBuf, status.c)
|
||||||
for i := range cTags {
|
for i := range cTags {
|
||||||
C.free(unsafe.Pointer(cTags[i]))
|
C.free(unsafe.Pointer(cTags[i]))
|
||||||
}
|
}
|
||||||
C.free(unsafe.Pointer(cExportDir))
|
C.free(unsafe.Pointer(cExportDir))
|
||||||
|
|
||||||
|
metaGraphDefBytes := C.GoBytes(metaGraphDefBuf.data, C.int(metaGraphDefBuf.length))
|
||||||
|
metaGraphDef := new(tfpb.MetaGraphDef)
|
||||||
|
if err := proto.Unmarshal(metaGraphDefBytes, metaGraphDef); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
signatures := generateSignatures(metaGraphDef.GetSignatureDef())
|
||||||
|
|
||||||
if err := status.Err(); err != nil {
|
if err := status.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s := &Session{c: cSess}
|
s := &Session{c: cSess}
|
||||||
runtime.SetFinalizer(s, func(s *Session) { s.Close() })
|
runtime.SetFinalizer(s, func(s *Session) { s.Close() })
|
||||||
return &SavedModel{Session: s, Graph: graph}, nil
|
return &SavedModel{Session: s, Graph: graph, Signatures: signatures}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSignatures(pb map[string]*tfpb.SignatureDef) map[string]Signature {
|
||||||
|
signatures := make(map[string]Signature)
|
||||||
|
for name, signature := range pb {
|
||||||
|
signatures[name] = signatureDefFromProto(signature)
|
||||||
|
}
|
||||||
|
return signatures
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,7 @@ func TestSavedModel(t *testing.T) {
|
|||||||
if op := bundle.Graph.Operation("y"); op == nil {
|
if op := bundle.Graph.Operation("y"); op == nil {
|
||||||
t.Fatalf("\"y\" not found in graph")
|
t.Fatalf("\"y\" not found in graph")
|
||||||
}
|
}
|
||||||
|
t.Logf("SavedModel: %+v", bundle)
|
||||||
// TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
|
// TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
|
||||||
// more thorough test when the generated protobufs are available.
|
// more thorough test when the generated protobufs are available.
|
||||||
}
|
}
|
||||||
|
119
tensorflow/go/signature.go
Normal file
119
tensorflow/go/signature.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package tensorflow
|
||||||
|
|
||||||
|
import tfpb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework"
|
||||||
|
|
||||||
|
// #include "tensorflow/c/c_api.h"
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
// A Signature defines the signature of a computation supported by a TensorFlow
|
||||||
|
// graph.
|
||||||
|
//
|
||||||
|
// For example, a model with two loss computations, sharing a single input,
|
||||||
|
// might have the following signature_def map.
|
||||||
|
//
|
||||||
|
// Note that across the two Signatures "loss_A" and "loss_B", the input key,
|
||||||
|
// output key, and method_name are identical, and will be used by system(s) that
|
||||||
|
// implement or rely upon this particular loss method. The output tensor names
|
||||||
|
// differ, demonstrating how different outputs can exist for the same method.
|
||||||
|
//
|
||||||
|
// signature_def {
|
||||||
|
// key: "loss_A"
|
||||||
|
// value {
|
||||||
|
// inputs {
|
||||||
|
// key: "input"
|
||||||
|
// value {
|
||||||
|
// name: "input:0"
|
||||||
|
// dtype: DT_STRING
|
||||||
|
// tensor_shape: ...
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// outputs {
|
||||||
|
// key: "loss_output"
|
||||||
|
// value {
|
||||||
|
// name: "loss_output_A:0"
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// tensor_shape: ...
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// ...
|
||||||
|
// method_name: "some/package/compute_loss"
|
||||||
|
// }
|
||||||
|
// signature_def {
|
||||||
|
// key: "loss_B"
|
||||||
|
// value {
|
||||||
|
// inputs {
|
||||||
|
// key: "input"
|
||||||
|
// value {
|
||||||
|
// name: "input:0"
|
||||||
|
// dtype: DT_STRING
|
||||||
|
// tensor_shape: ...
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// outputs {
|
||||||
|
// key: "loss_output"
|
||||||
|
// value {
|
||||||
|
// name: "loss_output_B:0"
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// tensor_shape: ...
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// ...
|
||||||
|
// method_name: "some/package/compute_loss"
|
||||||
|
// }
|
||||||
|
type Signature struct {
|
||||||
|
Inputs, Outputs map[string]TensorInfo
|
||||||
|
MethodName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// A TensorInfo contains the information about a Tensor necessary for feeding or retrieval.
|
||||||
|
type TensorInfo struct {
|
||||||
|
Name string
|
||||||
|
DType DataType
|
||||||
|
Shape Shape
|
||||||
|
}
|
||||||
|
|
||||||
|
func signatureDefFromProto(pb *tfpb.SignatureDef) Signature {
|
||||||
|
inputs := make(map[string]TensorInfo)
|
||||||
|
for name, input := range pb.GetInputs() {
|
||||||
|
inputs[name] = tensorInfoFromProto(input)
|
||||||
|
}
|
||||||
|
outputs := make(map[string]TensorInfo)
|
||||||
|
for name, output := range pb.GetOutputs() {
|
||||||
|
outputs[name] = tensorInfoFromProto(output)
|
||||||
|
}
|
||||||
|
return Signature{
|
||||||
|
Inputs: inputs,
|
||||||
|
Outputs: outputs,
|
||||||
|
MethodName: pb.GetMethodName(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func tensorInfoFromProto(pb *tfpb.TensorInfo) TensorInfo {
|
||||||
|
var dims []int64
|
||||||
|
for _, d := range pb.GetTensorShape().GetDim() {
|
||||||
|
dims = append(dims, d.GetSize())
|
||||||
|
}
|
||||||
|
return TensorInfo{
|
||||||
|
Name: pb.GetName(),
|
||||||
|
DType: DataType(C.TF_DataType(pb.GetDtype())),
|
||||||
|
Shape: MakeShape(dims...),
|
||||||
|
}
|
||||||
|
}
|
205
tensorflow/go/signature_test.go
Normal file
205
tensorflow/go/signature_test.go
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package tensorflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
tfpb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSignatureFromProto(t *testing.T) {
|
||||||
|
got := signatureDefFromProto(&tfpb.SignatureDef{
|
||||||
|
Inputs: map[string]*tfpb.TensorInfo{
|
||||||
|
"input_1": &tfpb.TensorInfo{
|
||||||
|
Encoding: &tfpb.TensorInfo_Name{
|
||||||
|
Name: "tensor_1",
|
||||||
|
},
|
||||||
|
Dtype: tfpb.DataType_DT_INT8,
|
||||||
|
TensorShape: &tfpb.TensorShapeProto{
|
||||||
|
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||||
|
{Size: 1},
|
||||||
|
{Size: 2},
|
||||||
|
{Size: 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"input_2": &tfpb.TensorInfo{
|
||||||
|
Encoding: &tfpb.TensorInfo_Name{
|
||||||
|
Name: "tensor_2",
|
||||||
|
},
|
||||||
|
Dtype: tfpb.DataType_DT_FLOAT,
|
||||||
|
TensorShape: &tfpb.TensorShapeProto{
|
||||||
|
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||||
|
{Size: 4},
|
||||||
|
{Size: 5},
|
||||||
|
{Size: 6},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Outputs: map[string]*tfpb.TensorInfo{
|
||||||
|
"output_1": &tfpb.TensorInfo{
|
||||||
|
Encoding: &tfpb.TensorInfo_Name{
|
||||||
|
Name: "tensor_3",
|
||||||
|
},
|
||||||
|
Dtype: tfpb.DataType_DT_STRING,
|
||||||
|
TensorShape: &tfpb.TensorShapeProto{
|
||||||
|
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||||
|
{Size: 1},
|
||||||
|
{Size: 2},
|
||||||
|
{Size: 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"output_2": &tfpb.TensorInfo{
|
||||||
|
Encoding: &tfpb.TensorInfo_Name{
|
||||||
|
Name: "tensor_4",
|
||||||
|
},
|
||||||
|
Dtype: tfpb.DataType_DT_BOOL,
|
||||||
|
TensorShape: &tfpb.TensorShapeProto{
|
||||||
|
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||||
|
{Size: 4},
|
||||||
|
{Size: 5},
|
||||||
|
{Size: 6},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MethodName: "method",
|
||||||
|
})
|
||||||
|
|
||||||
|
want := Signature{
|
||||||
|
Inputs: map[string]TensorInfo{
|
||||||
|
"input_1": TensorInfo{
|
||||||
|
Name: "tensor_1",
|
||||||
|
DType: Int8,
|
||||||
|
Shape: MakeShape(1, 2, 3),
|
||||||
|
},
|
||||||
|
"input_2": TensorInfo{
|
||||||
|
Name: "tensor_2",
|
||||||
|
DType: Float,
|
||||||
|
Shape: MakeShape(4, 5, 6),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Outputs: map[string]TensorInfo{
|
||||||
|
"output_1": TensorInfo{
|
||||||
|
Name: "tensor_3",
|
||||||
|
DType: String,
|
||||||
|
Shape: MakeShape(1, 2, 3),
|
||||||
|
},
|
||||||
|
"output_2": TensorInfo{
|
||||||
|
Name: "tensor_4",
|
||||||
|
DType: Bool,
|
||||||
|
Shape: MakeShape(4, 5, 6),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MethodName: "method",
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, input := range want.Inputs {
|
||||||
|
diff, err := diffTensorInfos(got.Inputs[k], input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Signature.Inputs[%s]: unable to diff TensorInfos: %v", k, err)
|
||||||
|
}
|
||||||
|
if diff != "" {
|
||||||
|
t.Errorf("Signature.Inputs[%s] diff:\n%s", k, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, output := range want.Outputs {
|
||||||
|
diff, err := diffTensorInfos(got.Outputs[k], output)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Signature.Outputs[%s]: unable to diff TensorInfos: %v", k, err)
|
||||||
|
}
|
||||||
|
if diff != "" {
|
||||||
|
t.Errorf("Signature.Outputs[%s] diff:\n%s", k, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.MethodName != want.MethodName {
|
||||||
|
t.Errorf("Signature.MethodName: got %q, want %q", got.MethodName, want.MethodName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTensorInfoFromProto(t *testing.T) {
|
||||||
|
got := tensorInfoFromProto(&tfpb.TensorInfo{
|
||||||
|
Encoding: &tfpb.TensorInfo_Name{
|
||||||
|
Name: "tensor",
|
||||||
|
},
|
||||||
|
Dtype: tfpb.DataType_DT_INT8,
|
||||||
|
TensorShape: &tfpb.TensorShapeProto{
|
||||||
|
Dim: []*tfpb.TensorShapeProto_Dim{
|
||||||
|
{Size: 1},
|
||||||
|
{Size: 2},
|
||||||
|
{Size: 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
want := TensorInfo{
|
||||||
|
Name: "tensor",
|
||||||
|
DType: Int8,
|
||||||
|
Shape: MakeShape(1, 2, 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
diff, err := diffTensorInfos(got, want)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to diff TensorInfos: %v", err)
|
||||||
|
}
|
||||||
|
if diff != "" {
|
||||||
|
t.Errorf("tensorInfoFromProto produced a diff (got -> want): %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func diffTensorInfos(a, b TensorInfo) (string, error) {
|
||||||
|
diff := ""
|
||||||
|
if a.Name != b.Name {
|
||||||
|
diff += fmt.Sprintf("Name: %q -> %q\n", a.Name, b.Name)
|
||||||
|
}
|
||||||
|
if a.DType != b.DType {
|
||||||
|
diff += fmt.Sprintf("DType: %v -> %v\n", a.DType, b.DType)
|
||||||
|
}
|
||||||
|
|
||||||
|
aShape, err := a.Shape.ToSlice()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
bShape, err := b.Shape.ToSlice()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
shapeLen := len(aShape)
|
||||||
|
if len(bShape) > shapeLen {
|
||||||
|
shapeLen = len(bShape)
|
||||||
|
}
|
||||||
|
for i := 0; i < shapeLen; i++ {
|
||||||
|
if i >= len(aShape) {
|
||||||
|
diff += fmt.Sprintf("+Shape[%d]: %d\n", i, bShape[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if i >= len(bShape) {
|
||||||
|
diff += fmt.Sprintf("-Shape[%d]: %d\n", i, aShape[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if aShape[i] != bShape[i] {
|
||||||
|
diff += fmt.Sprintf("Shape[%d]: %d -> %d\n", i, aShape[i], bShape[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return diff, nil
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user