Fix handling of strings and bytes in example input
Without this fix, strings and bytes features will result in "TypeError: 'foo' has type str, but expected one of: bytes" and "ValueError: Type <class 'bytes'> for value b'foo' is not supported for tf.train.Feature.", respectively. PiperOrigin-RevId: 336416629 Change-Id: I474472ab2f73330a29775a0aeeeb83b3b464ad18
This commit is contained in:
parent
aebb9e6017
commit
b4dcf68682
@ -590,6 +590,9 @@ def _create_example_string(example_dict):
|
|||||||
example.features.feature[feature_name].float_list.value.extend(
|
example.features.feature[feature_name].float_list.value.extend(
|
||||||
feature_list)
|
feature_list)
|
||||||
elif isinstance(feature_list[0], str):
|
elif isinstance(feature_list[0], str):
|
||||||
|
example.features.feature[feature_name].bytes_list.value.extend(
|
||||||
|
[f.encode('utf8') for f in feature_list])
|
||||||
|
elif isinstance(feature_list[0], bytes):
|
||||||
example.features.feature[feature_name].bytes_list.value.extend(
|
example.features.feature[feature_name].bytes_list.value.extend(
|
||||||
feature_list)
|
feature_list)
|
||||||
elif isinstance(feature_list[0], six.integer_types):
|
elif isinstance(feature_list[0], six.integer_types):
|
||||||
|
@ -12,9 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for SavedModelCLI tool.
|
"""Tests for SavedModelCLI tool."""
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -29,6 +27,7 @@ from absl.testing import parameterized
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from six import StringIO
|
from six import StringIO
|
||||||
|
|
||||||
|
from tensorflow.core.example import example_pb2
|
||||||
from tensorflow.core.framework import types_pb2
|
from tensorflow.core.framework import types_pb2
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
from tensorflow.python.debug.wrappers import local_cli_wrapper
|
from tensorflow.python.debug.wrappers import local_cli_wrapper
|
||||||
@ -177,8 +176,7 @@ signature_def['serving_default']:
|
|||||||
saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
|
saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
|
||||||
dummy_model = DummyModel()
|
dummy_model = DummyModel()
|
||||||
# Call with specific values to create new polymorphic function traces.
|
# Call with specific values to create new polymorphic function traces.
|
||||||
dummy_model.func1(
|
dummy_model.func1(constant_op.constant(5), constant_op.constant(9), True)
|
||||||
constant_op.constant(5), constant_op.constant(9), True)
|
|
||||||
dummy_model(constant_op.constant(5))
|
dummy_model(constant_op.constant(5))
|
||||||
save.save(dummy_model, saved_model_dir)
|
save.save(dummy_model, saved_model_dir)
|
||||||
self.parser = saved_model_cli.create_parser()
|
self.parser = saved_model_cli.create_parser()
|
||||||
@ -391,6 +389,33 @@ Defined Functions:
|
|||||||
self.assertTrue(len(input_dict) == 2)
|
self.assertTrue(len(input_dict) == 2)
|
||||||
self.assertTrue(len(input_expr_dict) == 2)
|
self.assertTrue(len(input_expr_dict) == 2)
|
||||||
|
|
||||||
|
def testInputPreProcessExamplesWithStrAndBytes(self):
|
||||||
|
input_examples_str = 'inputs=[{"text":["foo"], "bytes":[b"bar"]}]'
|
||||||
|
input_dict = saved_model_cli.preprocess_input_examples_arg_string(
|
||||||
|
input_examples_str)
|
||||||
|
feature = example_pb2.Example.FromString(input_dict['inputs'][0])
|
||||||
|
self.assertProtoEquals(
|
||||||
|
"""
|
||||||
|
features {
|
||||||
|
feature {
|
||||||
|
key: "bytes"
|
||||||
|
value {
|
||||||
|
bytes_list {
|
||||||
|
value: "bar"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
feature {
|
||||||
|
key: "text"
|
||||||
|
value {
|
||||||
|
bytes_list {
|
||||||
|
value: "foo"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""", feature)
|
||||||
|
|
||||||
def testInputPreProcessFileNames(self):
|
def testInputPreProcessFileNames(self):
|
||||||
input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
|
input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
|
||||||
r'input:0=c:\PROGRA~1\data.npy')
|
r'input:0=c:\PROGRA~1\data.npy')
|
||||||
@ -680,7 +705,8 @@ Defined Functions:
|
|||||||
def fake_wrapper_session(sess):
|
def fake_wrapper_session(sess):
|
||||||
return sess
|
return sess
|
||||||
|
|
||||||
with test.mock.patch.object(local_cli_wrapper,
|
with test.mock.patch.object(
|
||||||
|
local_cli_wrapper,
|
||||||
'LocalCLIDebugWrapperSession',
|
'LocalCLIDebugWrapperSession',
|
||||||
side_effect=fake_wrapper_session,
|
side_effect=fake_wrapper_session,
|
||||||
autospec=True) as fake:
|
autospec=True) as fake:
|
||||||
@ -720,11 +746,11 @@ Defined Functions:
|
|||||||
self.parser = saved_model_cli.create_parser()
|
self.parser = saved_model_cli.create_parser()
|
||||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||||
output_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
|
output_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
|
||||||
args = self.parser.parse_args(
|
args = self.parser.parse_args([
|
||||||
['aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
|
'aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
|
||||||
'--output_prefix', output_dir,
|
'--output_prefix', output_dir, '--cpp_class', 'Compiled',
|
||||||
'--cpp_class', 'Compiled',
|
'--signature_def_key', 'MISSING'
|
||||||
'--signature_def_key', 'MISSING'])
|
])
|
||||||
with self.assertRaisesRegex(ValueError, 'Unable to find signature_def'):
|
with self.assertRaisesRegex(ValueError, 'Unable to find signature_def'):
|
||||||
saved_model_cli.aot_compile_cpu(args)
|
saved_model_cli.aot_compile_cpu(args)
|
||||||
|
|
||||||
@ -785,9 +811,8 @@ Defined Functions:
|
|||||||
output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
|
output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
|
||||||
args = self.parser.parse_args([
|
args = self.parser.parse_args([
|
||||||
'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
|
'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
|
||||||
'--signature_def_key', 'func',
|
'--signature_def_key', 'func', '--output_prefix', output_prefix,
|
||||||
'--output_prefix', output_prefix, '--variables_to_feed',
|
'--variables_to_feed', variables_to_feed, '--cpp_class', 'Generated'
|
||||||
variables_to_feed, '--cpp_class', 'Generated'
|
|
||||||
]) # Use the default seving signature_key.
|
]) # Use the default seving signature_key.
|
||||||
with test.mock.patch.object(logging, 'warn') as captured_warn:
|
with test.mock.patch.object(logging, 'warn') as captured_warn:
|
||||||
saved_model_cli.aot_compile_cpu(args)
|
saved_model_cli.aot_compile_cpu(args)
|
||||||
@ -812,8 +837,8 @@ Defined Functions:
|
|||||||
if func == dummy_model.func_write:
|
if func == dummy_model.func_write:
|
||||||
# Writeable variables setters do not preserve constness.
|
# Writeable variables setters do not preserve constness.
|
||||||
self.assertIn('set_var_param_write_var_data(float', header_contents)
|
self.assertIn('set_var_param_write_var_data(float', header_contents)
|
||||||
self.assertNotIn(
|
self.assertNotIn('set_var_param_write_var_data(const float',
|
||||||
'set_var_param_write_var_data(const float', header_contents)
|
header_contents)
|
||||||
|
|
||||||
makefile_contents = file_io.read_file_to_string(
|
makefile_contents = file_io.read_file_to_string(
|
||||||
'{}_makefile.inc'.format(output_prefix))
|
'{}_makefile.inc'.format(output_prefix))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user