Fix handling of strings and bytes in example input

Without this fix, strings and bytes features will result in "TypeError: 'foo' has type str, but expected one of: bytes" and "ValueError: Type <class 'bytes'> for value b'foo' is not supported for tf.train.Feature.", respectively.

PiperOrigin-RevId: 336416629
Change-Id: I474472ab2f73330a29775a0aeeeb83b3b464ad18
This commit is contained in:
A. Unique TensorFlower 2020-10-09 21:46:24 -07:00 committed by TensorFlower Gardener
parent aebb9e6017
commit b4dcf68682
2 changed files with 47 additions and 19 deletions

View File

@ -590,6 +590,9 @@ def _create_example_string(example_dict):
example.features.feature[feature_name].float_list.value.extend(
feature_list)
elif isinstance(feature_list[0], str):
example.features.feature[feature_name].bytes_list.value.extend(
[f.encode('utf8') for f in feature_list])
elif isinstance(feature_list[0], bytes):
example.features.feature[feature_name].bytes_list.value.extend(
feature_list)
elif isinstance(feature_list[0], six.integer_types):

View File

@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SavedModelCLI tool.
"""
"""Tests for SavedModelCLI tool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -29,6 +27,7 @@ from absl.testing import parameterized
import numpy as np
from six import StringIO
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.debug.wrappers import local_cli_wrapper
@ -177,8 +176,7 @@ signature_def['serving_default']:
saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
dummy_model = DummyModel()
# Call with specific values to create new polymorphic function traces.
dummy_model.func1(
constant_op.constant(5), constant_op.constant(9), True)
dummy_model.func1(constant_op.constant(5), constant_op.constant(9), True)
dummy_model(constant_op.constant(5))
save.save(dummy_model, saved_model_dir)
self.parser = saved_model_cli.create_parser()
@ -391,6 +389,33 @@ Defined Functions:
self.assertTrue(len(input_dict) == 2)
self.assertTrue(len(input_expr_dict) == 2)
def testInputPreProcessExamplesWithStrAndBytes(self):
input_examples_str = 'inputs=[{"text":["foo"], "bytes":[b"bar"]}]'
input_dict = saved_model_cli.preprocess_input_examples_arg_string(
input_examples_str)
feature = example_pb2.Example.FromString(input_dict['inputs'][0])
self.assertProtoEquals(
"""
features {
feature {
key: "bytes"
value {
bytes_list {
value: "bar"
}
}
}
feature {
key: "text"
value {
bytes_list {
value: "foo"
}
}
}
}
""", feature)
def testInputPreProcessFileNames(self):
input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
r'input:0=c:\PROGRA~1\data.npy')
@ -680,10 +705,11 @@ Defined Functions:
def fake_wrapper_session(sess):
return sess
with test.mock.patch.object(local_cli_wrapper,
'LocalCLIDebugWrapperSession',
side_effect=fake_wrapper_session,
autospec=True) as fake:
with test.mock.patch.object(
local_cli_wrapper,
'LocalCLIDebugWrapperSession',
side_effect=fake_wrapper_session,
autospec=True) as fake:
saved_model_cli.run(args)
fake.assert_called_with(test.mock.ANY)
@ -720,11 +746,11 @@ Defined Functions:
self.parser = saved_model_cli.create_parser()
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
output_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
args = self.parser.parse_args(
['aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
'--output_prefix', output_dir,
'--cpp_class', 'Compiled',
'--signature_def_key', 'MISSING'])
args = self.parser.parse_args([
'aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
'--output_prefix', output_dir, '--cpp_class', 'Compiled',
'--signature_def_key', 'MISSING'
])
with self.assertRaisesRegex(ValueError, 'Unable to find signature_def'):
saved_model_cli.aot_compile_cpu(args)
@ -785,9 +811,8 @@ Defined Functions:
output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
args = self.parser.parse_args([
'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
'--signature_def_key', 'func',
'--output_prefix', output_prefix, '--variables_to_feed',
variables_to_feed, '--cpp_class', 'Generated'
'--signature_def_key', 'func', '--output_prefix', output_prefix,
'--variables_to_feed', variables_to_feed, '--cpp_class', 'Generated'
]) # Use the default seving signature_key.
with test.mock.patch.object(logging, 'warn') as captured_warn:
saved_model_cli.aot_compile_cpu(args)
@ -812,8 +837,8 @@ Defined Functions:
if func == dummy_model.func_write:
# Writeable variables setters do not preserve constness.
self.assertIn('set_var_param_write_var_data(float', header_contents)
self.assertNotIn(
'set_var_param_write_var_data(const float', header_contents)
self.assertNotIn('set_var_param_write_var_data(const float',
header_contents)
makefile_contents = file_io.read_file_to_string(
'{}_makefile.inc'.format(output_prefix))