Fix a bug on setting string input with np.array

Updated a logic for tf.lite.Interpreter.set_tensor() implementation.

PiperOrigin-RevId: 337439640
Change-Id: I02b51139b9e75b3442188d96d7d7124ad6d3b68b
This commit is contained in:
Terry Heo 2020-10-15 21:13:56 -07:00 committed by TensorFlower Gardener
parent 2bb30c589a
commit fe651f1252
4 changed files with 136 additions and 0 deletions

View File

@ -218,6 +218,18 @@ class InterpreterTest(test_util.TensorFlowTestCase):
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
def testStringZeroDim(self):
data = b'abcd' + bytes(16)
interpreter = interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile(
'testdata/gather_string_0d.tflite'))
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
interpreter.set_tensor(input_details[0]['index'], np.array(data))
test_input_tensor = interpreter.get_tensor(input_details[0]['index'])
self.assertEqual(len(data), len(test_input_tensor.item(0)))
def testPerChannelParams(self):
interpreter = interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))

View File

@ -153,6 +153,11 @@ bool FillStringBufferWithPyArray(PyObject* value,
case NPY_OBJECT:
case NPY_STRING:
case NPY_UNICODE: {
if (PyArray_NDIM(array) == 0) {
dynamic_buffer->AddString(static_cast<char*>(PyArray_DATA(array)),
PyArray_NBYTES(array));
return true;
}
UniquePyObjectRef iter(PyArray_IterNew(value));
while (PyArray_ITER_NOTDONE(iter.get())) {
UniquePyObjectRef item(PyArray_GETITEM(

View File

@ -43,11 +43,22 @@ DEPRECATED_tf_to_tflite(
],
)
DEPRECATED_tf_to_tflite(
name = "gather_string_0d",
src = "gather_0d.pbtxt",
out = "gather_string_0d.tflite",
options = [
"--input_arrays=input,indices",
"--output_arrays=output",
],
)
filegroup(
name = "interpreter_test_data",
srcs = [
"pc_conv.bin",
":gather_string",
":gather_string_0d",
":permute_float",
":permute_uint8",
],

View File

@ -0,0 +1,108 @@
node {
name: "input"
op: "Placeholder"
device: "/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
}
node {
name: "input_const"
op: "Const"
device: "/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: "abcd"
}
}
}
}
node {
name: "indices"
op: "Placeholder"
device: "/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_INT64
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
}
}
}
}
node {
name: "axis"
op: "Const"
device: "/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "output"
op: "GatherV2"
input: "input_const"
input: "indices"
input: "axis"
device: "/device:CPU:0"
attr {
key: "Taxis"
value {
type: DT_INT32
}
}
attr {
key: "Tindices"
value {
type: DT_INT64
}
}
attr {
key: "Tparams"
value {
type: DT_STRING
}
}
}
versions {
producer: 27
}