Go: fix handling empty tags-set for loading saved model

This commit is contained in:
Alexander Bayandin 2020-02-04 16:29:48 +00:00
parent 2c5e22190c
commit 4a0c2fae4a
2 changed files with 14 additions and 1 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
package tensorflow
import (
"fmt"
"runtime"
"unsafe"
@ -57,6 +58,9 @@ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*
return nil, err
}
cExportDir := C.CString(exportDir)
if len(tags) == 0 {
return nil, fmt.Errorf("empty tags are not allowed")
}
cTags := make([]*C.char, len(tags))
for i := range tags {
cTags[i] = C.CString(tags[i])

View File

@ -19,7 +19,8 @@ package tensorflow
import "testing"
func TestSavedModel(t *testing.T) {
bundle, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", []string{"serve"}, nil)
tags := []string{"serve"}
bundle, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", tags, nil)
if err != nil {
t.Fatalf("LoadSavedModel(): %v", err)
}
@ -30,3 +31,11 @@ func TestSavedModel(t *testing.T) {
// TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
// more thorough test when the generated protobufs are available.
}
func TestSavedModelWithEmptyTags(t *testing.T) {
tags := []string{}
_, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", tags, nil)
if err == nil {
t.Fatalf("LoadSavedModel() should return an error if tags are empty")
}
}