Go: fix handling empty tags-set for loading saved model
This commit is contained in:
parent
2c5e22190c
commit
4a0c2fae4a
@ -17,6 +17,7 @@ limitations under the License.
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
|
||||
@ -57,6 +58,9 @@ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*
|
||||
return nil, err
|
||||
}
|
||||
cExportDir := C.CString(exportDir)
|
||||
if len(tags) == 0 {
|
||||
return nil, fmt.Errorf("empty tags are not allowed")
|
||||
}
|
||||
cTags := make([]*C.char, len(tags))
|
||||
for i := range tags {
|
||||
cTags[i] = C.CString(tags[i])
|
||||
|
@ -19,7 +19,8 @@ package tensorflow
|
||||
import "testing"
|
||||
|
||||
func TestSavedModel(t *testing.T) {
|
||||
bundle, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", []string{"serve"}, nil)
|
||||
tags := []string{"serve"}
|
||||
bundle, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", tags, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadSavedModel(): %v", err)
|
||||
}
|
||||
@ -30,3 +31,11 @@ func TestSavedModel(t *testing.T) {
|
||||
// TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
|
||||
// more thorough test when the generated protobufs are available.
|
||||
}
|
||||
|
||||
func TestSavedModelWithEmptyTags(t *testing.T) {
|
||||
tags := []string{}
|
||||
_, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", tags, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("LoadSavedModel() should return an error if tags are empty")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user