From 4911ca4a011c7b86087e4a420247fdc675ad8a88 Mon Sep 17 00:00:00 2001 From: Yunlu Li Date: Fri, 14 Feb 2020 15:47:49 -0800 Subject: [PATCH] Make interpreter python wrapper handle sparse tensor properly. PiperOrigin-RevId: 295244210 Change-Id: I75c76ac4974051325e285191464613500738951b --- tensorflow/lite/python/BUILD | 1 + tensorflow/lite/python/interpreter.py | 5 +- tensorflow/lite/python/interpreter_test.py | 37 +++++++++++ .../lite/python/interpreter_wrapper/BUILD | 1 + .../interpreter_wrapper.cc | 61 ++++++++++++++++++- .../interpreter_wrapper/interpreter_wrapper.h | 1 + .../lite/python/testdata/test_registerer.cc | 14 ++++- tensorflow/lite/util.cc | 12 ++-- 8 files changed, 121 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index f9910e08812..61e36aac4b7 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -29,6 +29,7 @@ py_test( name = "interpreter_test", srcs = ["interpreter_test.py"], data = [ + "//tensorflow/lite:testdata/sparse_tensor.bin", "//tensorflow/lite/python/testdata:interpreter_test_data", "//tensorflow/lite/python/testdata:test_delegate.so", ], diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index 4acedabeab9..0dad5e332ec 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -325,6 +325,8 @@ class Interpreter(object): tensor_quantization = self._interpreter.TensorQuantization(tensor_index) tensor_quantization_params = self._interpreter.TensorQuantizationParameters( tensor_index) + tensor_sparsity_params = self._interpreter.TensorSparsityParameters( + tensor_index) if not tensor_name or not tensor_type: raise ValueError('Could not get tensor details') @@ -340,7 +342,8 @@ class Interpreter(object): 'scales': tensor_quantization_params[0], 'zero_points': tensor_quantization_params[1], 'quantized_dimension': tensor_quantization_params[2], - } + }, + 'sparsity_parameters': tensor_sparsity_params } return details diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index 9c8dbbaa9c2..122f5f2d04c 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -190,6 +190,43 @@ class InterpreterTest(test_util.TensorFlowTestCase): # Ensure that we retrieve per channel quantization params correctly. self.assertEqual(len(qparams['scales']), 128) + def testDenseTensorAccess(self): + interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin')) + interpreter.allocate_tensors() + weight_details = interpreter.get_tensor_details()[1] + s_params = weight_details['sparsity_parameters'] + self.assertEqual(s_params, {}) + + def testSparseTensorAccess(self): + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_path=resource_loader.get_path_to_datafile( + '../testdata/sparse_tensor.bin'), + custom_op_registerers=['TF_TestRegisterer']) + interpreter.allocate_tensors() + + # Tensor at index 0 is sparse. + compressed_buffer = interpreter.get_tensor(0) + # Ensure that the buffer is of correct size and value. + self.assertEqual(len(compressed_buffer), 12) + sparse_value = [1, 0, 0, 4, 2, 3, 0, 0, 5, 0, 0, 6] + self.assertAllEqual(compressed_buffer, sparse_value) + + tensor_details = interpreter.get_tensor_details()[0] + s_params = tensor_details['sparsity_parameters'] + + # Ensure sparsity parameter returned is correct + self.assertAllEqual(s_params['traversal_order'], [0, 1, 2, 3]) + self.assertAllEqual(s_params['block_map'], [0, 1]) + dense_dim_metadata = {'format': 0, 'dense_size': 2} + self.assertAllEqual(s_params['dim_metadata'][0], dense_dim_metadata) + self.assertAllEqual(s_params['dim_metadata'][2], dense_dim_metadata) + self.assertAllEqual(s_params['dim_metadata'][3], dense_dim_metadata) + self.assertEqual(s_params['dim_metadata'][1]['format'], 1) + self.assertAllEqual(s_params['dim_metadata'][1]['array_segments'], + [0, 2, 3]) + self.assertAllEqual(s_params['dim_metadata'][1]['array_indices'], [0, 1, 1]) + class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD index 9041f712d60..15dc9ec4376 100644 --- a/tensorflow/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -29,6 +29,7 @@ cc_library( ":python_utils", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", + "//tensorflow/lite:util", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", "//tensorflow/lite/experimental/tflite_api_dispatcher", diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 372745bc479..9993d0211c2 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/util.h" #define TFLITE_PY_CHECK(x) \ if ((x) != kTfLiteOk) { \ @@ -110,6 +111,38 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { return result; } +PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) { + PyObject* result = PyDict_New(); + PyDict_SetItemString(result, "traversal_order", + PyArrayFromIntVector(param.traversal_order->data, + param.traversal_order->size)); + PyDict_SetItemString( + result, "block_map", + PyArrayFromIntVector(param.block_map->data, param.block_map->size)); + PyObject* dim_metadata = PyList_New(param.dim_metadata_size); + for (int i = 0; i < param.dim_metadata_size; i++) { + PyObject* dim_metadata_i = PyDict_New(); + if (param.dim_metadata[i].format == kTfLiteDimDense) { + PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(0)); + PyDict_SetItemString(dim_metadata_i, "dense_size", + PyLong_FromSize_t(param.dim_metadata[i].dense_size)); + } else { + PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(1)); + const auto* array_segments = param.dim_metadata[i].array_segments; + const auto* array_indices = param.dim_metadata[i].array_indices; + PyDict_SetItemString( + dim_metadata_i, "array_segments", + PyArrayFromIntVector(array_segments->data, array_segments->size)); + PyDict_SetItemString( + dim_metadata_i, "array_indices", + PyArrayFromIntVector(array_indices->data, array_indices->size)); + } + PyList_SetItem(dim_metadata, i, dim_metadata_i); + } + PyDict_SetItemString(result, "dim_metadata", dim_metadata); + return result; +} + bool RegisterCustomOpByName(const char* registerer_name, tflite::MutableOpResolver* resolver, std::string* error_msg) { @@ -322,6 +355,17 @@ PyObject* InterpreterWrapper::TensorSizeSignature(int i) const { return PyArray_Return(reinterpret_cast(np_array)); } +PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); + const TfLiteTensor* tensor = interpreter_->tensor(i); + if (tensor->sparsity == nullptr) { + return PyDict_New(); + } + + return PyDictFromSparsityParam(*tensor->sparsity); +} + PyObject* InterpreterWrapper::TensorQuantization(int i) const { TFLITE_PY_ENSURE_VALID_INTERPRETER(); TFLITE_PY_TENSOR_BOUNDS_CHECK(i); @@ -536,8 +580,21 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { return nullptr; } memcpy(data, tensor->data.raw, tensor->bytes); - PyObject* np_array = - PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); + PyObject* np_array; + if (tensor->sparsity == nullptr) { + np_array = + PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); + } else { + std::vector sparse_buffer_dims(1); + size_t size_of_type; + if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) { + PyErr_SetString(PyExc_ValueError, "Unknown tensor type."); + return nullptr; + } + sparse_buffer_dims[0] = tensor->bytes / size_of_type; + np_array = PyArray_SimpleNewFromData( + sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num, data); + } PyArray_ENABLEFLAGS(reinterpret_cast(np_array), NPY_ARRAY_OWNDATA); return PyArray_Return(reinterpret_cast(np_array)); diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index c37d3e998cd..8a5ff215f3a 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -70,6 +70,7 @@ class InterpreterWrapper { PyObject* TensorType(int i) const; PyObject* TensorSize(int i) const; PyObject* TensorSizeSignature(int i) const; + PyObject* TensorSparsityParameters(int i) const; // Deprecated in favor of TensorQuantizationScales, below. PyObject* TensorQuantization(int i) const; PyObject* TensorQuantizationParameters(int i) const; diff --git a/tensorflow/lite/python/testdata/test_registerer.cc b/tensorflow/lite/python/testdata/test_registerer.cc index 6adde65a863..8c4710a1902 100644 --- a/tensorflow/lite/python/testdata/test_registerer.cc +++ b/tensorflow/lite/python/testdata/test_registerer.cc @@ -18,12 +18,20 @@ namespace tflite { namespace { static int num_test_registerer_calls = 0; + +TfLiteRegistration* GetFakeRegistration() { + static TfLiteRegistration fake_op; + return &fake_op; +} + } // namespace -// Dummy registerer function with the correct signature. Ignores the resolver -// but increments the num_test_registerer_calls counter by one. The TF_ prefix -// is needed to get past the version script in the OSS build. +// Dummy registerer function with the correct signature. Registers a fake custom +// op needed by test models. Increments the num_test_registerer_calls counter by +// one. The TF_ prefix is needed to get past the version script in the OSS +// build. extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver) { + resolver->AddCustom("FakeOp", GetFakeRegistration()); num_test_registerer_calls++; } diff --git a/tensorflow/lite/util.cc b/tensorflow/lite/util.cc index 7a9470976f9..a876eebd639 100644 --- a/tensorflow/lite/util.cc +++ b/tensorflow/lite/util.cc @@ -103,11 +103,13 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type, *bytes = sizeof(TfLiteFloat16); break; default: - context->ReportError( - context, - "Type %d is unsupported. Only float32, int8, int16, int32, int64, " - "uint8, bool, complex64 supported currently.", - type); + if (context) { + context->ReportError( + context, + "Type %d is unsupported. Only float32, int8, int16, int32, int64, " + "uint8, bool, complex64 supported currently.", + type); + } return kTfLiteError; } return kTfLiteOk;