Go: fix handling empty tags-set for loading saved model

This commit is contained in:
Alexander Bayandin 2020-02-04 16:29:48 +00:00
parent 2c5e22190c
commit 4a0c2fae4a
2 changed files with 14 additions and 1 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
package tensorflow package tensorflow
import ( import (
"fmt"
"runtime" "runtime"
"unsafe" "unsafe"
@ -57,6 +58,9 @@ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*
return nil, err return nil, err
} }
cExportDir := C.CString(exportDir) cExportDir := C.CString(exportDir)
if len(tags) == 0 {
return nil, fmt.Errorf("empty tags are not allowed")
}
cTags := make([]*C.char, len(tags)) cTags := make([]*C.char, len(tags))
for i := range tags { for i := range tags {
cTags[i] = C.CString(tags[i]) cTags[i] = C.CString(tags[i])

View File

@ -19,7 +19,8 @@ package tensorflow
import "testing" import "testing"
func TestSavedModel(t *testing.T) { func TestSavedModel(t *testing.T) {
bundle, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", []string{"serve"}, nil) tags := []string{"serve"}
bundle, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", tags, nil)
if err != nil { if err != nil {
t.Fatalf("LoadSavedModel(): %v", err) t.Fatalf("LoadSavedModel(): %v", err)
} }
@ -30,3 +31,11 @@ func TestSavedModel(t *testing.T) {
// TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a // TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
// more thorough test when the generated protobufs are available. // more thorough test when the generated protobufs are available.
} }
func TestSavedModelWithEmptyTags(t *testing.T) {
tags := []string{}
_, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", tags, nil)
if err == nil {
t.Fatalf("LoadSavedModel() should return an error if tags are empty")
}
}