From 4a0c2fae4a53c2024228c8a942920ce1a3af8410 Mon Sep 17 00:00:00 2001 From: Alexander Bayandin Date: Tue, 4 Feb 2020 16:29:48 +0000 Subject: [PATCH] Go: fix handling empty tags-set for loading saved model --- tensorflow/go/saved_model.go | 4 ++++ tensorflow/go/saved_model_test.go | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tensorflow/go/saved_model.go b/tensorflow/go/saved_model.go index 0fbd6081ef2..7aa1e83cbc4 100644 --- a/tensorflow/go/saved_model.go +++ b/tensorflow/go/saved_model.go @@ -17,6 +17,7 @@ limitations under the License. package tensorflow import ( + "fmt" "runtime" "unsafe" @@ -57,6 +58,9 @@ func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (* return nil, err } cExportDir := C.CString(exportDir) + if len(tags) == 0 { + return nil, fmt.Errorf("empty tags are not allowed") + } cTags := make([]*C.char, len(tags)) for i := range tags { cTags[i] = C.CString(tags[i]) diff --git a/tensorflow/go/saved_model_test.go b/tensorflow/go/saved_model_test.go index 0ff448e5661..24811d692af 100644 --- a/tensorflow/go/saved_model_test.go +++ b/tensorflow/go/saved_model_test.go @@ -19,7 +19,8 @@ package tensorflow import "testing" func TestSavedModel(t *testing.T) { - bundle, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", []string{"serve"}, nil) + tags := []string{"serve"} + bundle, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", tags, nil) if err != nil { t.Fatalf("LoadSavedModel(): %v", err) } @@ -30,3 +31,11 @@ func TestSavedModel(t *testing.T) { // TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a // more thorough test when the generated protobufs are available. } + +func TestSavedModelWithEmptyTags(t *testing.T) { + tags := []string{} + _, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", tags, nil) + if err == nil { + t.Fatalf("LoadSavedModel() should return an error if tags are empty") + } +}