From fe651f125271264aa13a415932fd7061621aba40 Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Thu, 15 Oct 2020 21:13:56 -0700 Subject: [PATCH] Fix a bug on setting string input with np.array Updated a logic for tf.lite.Interpreter.set_tensor() implementation. PiperOrigin-RevId: 337439640 Change-Id: I02b51139b9e75b3442188d96d7d7124ad6d3b68b --- tensorflow/lite/python/interpreter_test.py | 12 ++ .../lite/python/interpreter_wrapper/numpy.cc | 5 + tensorflow/lite/python/testdata/BUILD | 11 ++ .../lite/python/testdata/gather_0d.pbtxt | 108 ++++++++++++++++++ 4 files changed, 136 insertions(+) create mode 100644 tensorflow/lite/python/testdata/gather_0d.pbtxt diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index bcb338b84cf..62bd9710f23 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -218,6 +218,18 @@ class InterpreterTest(test_util.TensorFlowTestCase): output_data = interpreter.get_tensor(output_details[0]['index']) self.assertTrue((expected_output == output_data).all()) + def testStringZeroDim(self): + data = b'abcd' + bytes(16) + interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/gather_string_0d.tflite')) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + interpreter.set_tensor(input_details[0]['index'], np.array(data)) + test_input_tensor = interpreter.get_tensor(input_details[0]['index']) + self.assertEqual(len(data), len(test_input_tensor.item(0))) + def testPerChannelParams(self): interpreter = interpreter_wrapper.Interpreter( model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin')) diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc index d2f308a74a2..b854e2ebd69 100644 --- a/tensorflow/lite/python/interpreter_wrapper/numpy.cc +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc @@ -153,6 +153,11 @@ bool FillStringBufferWithPyArray(PyObject* value, case NPY_OBJECT: case NPY_STRING: case NPY_UNICODE: { + if (PyArray_NDIM(array) == 0) { + dynamic_buffer->AddString(static_cast(PyArray_DATA(array)), + PyArray_NBYTES(array)); + return true; + } UniquePyObjectRef iter(PyArray_IterNew(value)); while (PyArray_ITER_NOTDONE(iter.get())) { UniquePyObjectRef item(PyArray_GETITEM( diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD index a1764d8fead..83f2d14666b 100644 --- a/tensorflow/lite/python/testdata/BUILD +++ b/tensorflow/lite/python/testdata/BUILD @@ -43,11 +43,22 @@ DEPRECATED_tf_to_tflite( ], ) +DEPRECATED_tf_to_tflite( + name = "gather_string_0d", + src = "gather_0d.pbtxt", + out = "gather_string_0d.tflite", + options = [ + "--input_arrays=input,indices", + "--output_arrays=output", + ], +) + filegroup( name = "interpreter_test_data", srcs = [ "pc_conv.bin", ":gather_string", + ":gather_string_0d", ":permute_float", ":permute_uint8", ], diff --git a/tensorflow/lite/python/testdata/gather_0d.pbtxt b/tensorflow/lite/python/testdata/gather_0d.pbtxt new file mode 100644 index 00000000000..b065cb22a4e --- /dev/null +++ b/tensorflow/lite/python/testdata/gather_0d.pbtxt @@ -0,0 +1,108 @@ +node { + name: "input" + op: "Placeholder" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_STRING + } + } +} +node { + name: "input_const" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "abcd" + } + } + } +} +node { + name: "indices" + op: "Placeholder" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + } + } + } +} +node { + name: "axis" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "output" + op: "GatherV2" + input: "input_const" + input: "indices" + input: "axis" + device: "/device:CPU:0" + attr { + key: "Taxis" + value { + type: DT_INT32 + } + } + attr { + key: "Tindices" + value { + type: DT_INT64 + } + } + attr { + key: "Tparams" + value { + type: DT_STRING + } + } +} +versions { + producer: 27 +}