Fix handling of strings and bytes in example input
Without this fix, strings and bytes features will result in "TypeError: 'foo' has type str, but expected one of: bytes" and "ValueError: Type <class 'bytes'> for value b'foo' is not supported for tf.train.Feature.", respectively. PiperOrigin-RevId: 336416629 Change-Id: I474472ab2f73330a29775a0aeeeb83b3b464ad18
This commit is contained in:
parent
aebb9e6017
commit
b4dcf68682
@ -590,6 +590,9 @@ def _create_example_string(example_dict):
|
||||
example.features.feature[feature_name].float_list.value.extend(
|
||||
feature_list)
|
||||
elif isinstance(feature_list[0], str):
|
||||
example.features.feature[feature_name].bytes_list.value.extend(
|
||||
[f.encode('utf8') for f in feature_list])
|
||||
elif isinstance(feature_list[0], bytes):
|
||||
example.features.feature[feature_name].bytes_list.value.extend(
|
||||
feature_list)
|
||||
elif isinstance(feature_list[0], six.integer_types):
|
||||
|
@ -12,9 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for SavedModelCLI tool.
|
||||
|
||||
"""
|
||||
"""Tests for SavedModelCLI tool."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -29,6 +27,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from six import StringIO
|
||||
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python.debug.wrappers import local_cli_wrapper
|
||||
@ -177,8 +176,7 @@ signature_def['serving_default']:
|
||||
saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
|
||||
dummy_model = DummyModel()
|
||||
# Call with specific values to create new polymorphic function traces.
|
||||
dummy_model.func1(
|
||||
constant_op.constant(5), constant_op.constant(9), True)
|
||||
dummy_model.func1(constant_op.constant(5), constant_op.constant(9), True)
|
||||
dummy_model(constant_op.constant(5))
|
||||
save.save(dummy_model, saved_model_dir)
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
@ -391,6 +389,33 @@ Defined Functions:
|
||||
self.assertTrue(len(input_dict) == 2)
|
||||
self.assertTrue(len(input_expr_dict) == 2)
|
||||
|
||||
def testInputPreProcessExamplesWithStrAndBytes(self):
|
||||
input_examples_str = 'inputs=[{"text":["foo"], "bytes":[b"bar"]}]'
|
||||
input_dict = saved_model_cli.preprocess_input_examples_arg_string(
|
||||
input_examples_str)
|
||||
feature = example_pb2.Example.FromString(input_dict['inputs'][0])
|
||||
self.assertProtoEquals(
|
||||
"""
|
||||
features {
|
||||
feature {
|
||||
key: "bytes"
|
||||
value {
|
||||
bytes_list {
|
||||
value: "bar"
|
||||
}
|
||||
}
|
||||
}
|
||||
feature {
|
||||
key: "text"
|
||||
value {
|
||||
bytes_list {
|
||||
value: "foo"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""", feature)
|
||||
|
||||
def testInputPreProcessFileNames(self):
|
||||
input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
|
||||
r'input:0=c:\PROGRA~1\data.npy')
|
||||
@ -680,10 +705,11 @@ Defined Functions:
|
||||
def fake_wrapper_session(sess):
|
||||
return sess
|
||||
|
||||
with test.mock.patch.object(local_cli_wrapper,
|
||||
'LocalCLIDebugWrapperSession',
|
||||
side_effect=fake_wrapper_session,
|
||||
autospec=True) as fake:
|
||||
with test.mock.patch.object(
|
||||
local_cli_wrapper,
|
||||
'LocalCLIDebugWrapperSession',
|
||||
side_effect=fake_wrapper_session,
|
||||
autospec=True) as fake:
|
||||
saved_model_cli.run(args)
|
||||
fake.assert_called_with(test.mock.ANY)
|
||||
|
||||
@ -720,11 +746,11 @@ Defined Functions:
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
output_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
|
||||
args = self.parser.parse_args(
|
||||
['aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
|
||||
'--output_prefix', output_dir,
|
||||
'--cpp_class', 'Compiled',
|
||||
'--signature_def_key', 'MISSING'])
|
||||
args = self.parser.parse_args([
|
||||
'aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
|
||||
'--output_prefix', output_dir, '--cpp_class', 'Compiled',
|
||||
'--signature_def_key', 'MISSING'
|
||||
])
|
||||
with self.assertRaisesRegex(ValueError, 'Unable to find signature_def'):
|
||||
saved_model_cli.aot_compile_cpu(args)
|
||||
|
||||
@ -785,9 +811,8 @@ Defined Functions:
|
||||
output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
|
||||
args = self.parser.parse_args([
|
||||
'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
|
||||
'--signature_def_key', 'func',
|
||||
'--output_prefix', output_prefix, '--variables_to_feed',
|
||||
variables_to_feed, '--cpp_class', 'Generated'
|
||||
'--signature_def_key', 'func', '--output_prefix', output_prefix,
|
||||
'--variables_to_feed', variables_to_feed, '--cpp_class', 'Generated'
|
||||
]) # Use the default seving signature_key.
|
||||
with test.mock.patch.object(logging, 'warn') as captured_warn:
|
||||
saved_model_cli.aot_compile_cpu(args)
|
||||
@ -812,8 +837,8 @@ Defined Functions:
|
||||
if func == dummy_model.func_write:
|
||||
# Writeable variables setters do not preserve constness.
|
||||
self.assertIn('set_var_param_write_var_data(float', header_contents)
|
||||
self.assertNotIn(
|
||||
'set_var_param_write_var_data(const float', header_contents)
|
||||
self.assertNotIn('set_var_param_write_var_data(const float',
|
||||
header_contents)
|
||||
|
||||
makefile_contents = file_io.read_file_to_string(
|
||||
'{}_makefile.inc'.format(output_prefix))
|
||||
|
Loading…
Reference in New Issue
Block a user