Fix a bug on setting string input with np.array
Updated a logic for tf.lite.Interpreter.set_tensor() implementation. PiperOrigin-RevId: 337439640 Change-Id: I02b51139b9e75b3442188d96d7d7124ad6d3b68b
This commit is contained in:
parent
2bb30c589a
commit
fe651f1252
@ -218,6 +218,18 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertTrue((expected_output == output_data).all())
|
||||
|
||||
def testStringZeroDim(self):
|
||||
data = b'abcd' + bytes(16)
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/gather_string_0d.tflite'))
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
interpreter.set_tensor(input_details[0]['index'], np.array(data))
|
||||
test_input_tensor = interpreter.get_tensor(input_details[0]['index'])
|
||||
self.assertEqual(len(data), len(test_input_tensor.item(0)))
|
||||
|
||||
def testPerChannelParams(self):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
|
||||
|
@ -153,6 +153,11 @@ bool FillStringBufferWithPyArray(PyObject* value,
|
||||
case NPY_OBJECT:
|
||||
case NPY_STRING:
|
||||
case NPY_UNICODE: {
|
||||
if (PyArray_NDIM(array) == 0) {
|
||||
dynamic_buffer->AddString(static_cast<char*>(PyArray_DATA(array)),
|
||||
PyArray_NBYTES(array));
|
||||
return true;
|
||||
}
|
||||
UniquePyObjectRef iter(PyArray_IterNew(value));
|
||||
while (PyArray_ITER_NOTDONE(iter.get())) {
|
||||
UniquePyObjectRef item(PyArray_GETITEM(
|
||||
|
11
tensorflow/lite/python/testdata/BUILD
vendored
11
tensorflow/lite/python/testdata/BUILD
vendored
@ -43,11 +43,22 @@ DEPRECATED_tf_to_tflite(
|
||||
],
|
||||
)
|
||||
|
||||
DEPRECATED_tf_to_tflite(
|
||||
name = "gather_string_0d",
|
||||
src = "gather_0d.pbtxt",
|
||||
out = "gather_string_0d.tflite",
|
||||
options = [
|
||||
"--input_arrays=input,indices",
|
||||
"--output_arrays=output",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "interpreter_test_data",
|
||||
srcs = [
|
||||
"pc_conv.bin",
|
||||
":gather_string",
|
||||
":gather_string_0d",
|
||||
":permute_float",
|
||||
":permute_uint8",
|
||||
],
|
||||
|
108
tensorflow/lite/python/testdata/gather_0d.pbtxt
vendored
Normal file
108
tensorflow/lite/python/testdata/gather_0d.pbtxt
vendored
Normal file
@ -0,0 +1,108 @@
|
||||
node {
|
||||
name: "input"
|
||||
op: "Placeholder"
|
||||
device: "/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input_const"
|
||||
op: "Const"
|
||||
device: "/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
string_val: "abcd"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "indices"
|
||||
op: "Placeholder"
|
||||
device: "/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "axis"
|
||||
op: "Const"
|
||||
device: "/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "output"
|
||||
op: "GatherV2"
|
||||
input: "input_const"
|
||||
input: "indices"
|
||||
input: "axis"
|
||||
device: "/device:CPU:0"
|
||||
attr {
|
||||
key: "Taxis"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tindices"
|
||||
value {
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tparams"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user