Make interpreter python wrapper handle sparse tensor properly.

PiperOrigin-RevId: 295244210
Change-Id: I75c76ac4974051325e285191464613500738951b
This commit is contained in:
Yunlu Li 2020-02-14 15:47:49 -08:00 committed by TensorFlower Gardener
parent d6741e0994
commit 4911ca4a01
8 changed files with 121 additions and 11 deletions

View File

@@ -29,6 +29,7 @@ py_test(
name = "interpreter_test",
srcs = ["interpreter_test.py"],
data = [
"//tensorflow/lite:testdata/sparse_tensor.bin",
"//tensorflow/lite/python/testdata:interpreter_test_data",
"//tensorflow/lite/python/testdata:test_delegate.so",
],

View File

@@ -325,6 +325,8 @@ class Interpreter(object):
tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
tensor_index)
tensor_sparsity_params = self._interpreter.TensorSparsityParameters(
tensor_index)
if not tensor_name or not tensor_type:
raise ValueError('Could not get tensor details')
@@ -340,7 +342,8 @@
'scales': tensor_quantization_params[0],
'zero_points': tensor_quantization_params[1],
'quantized_dimension': tensor_quantization_params[2],
}
},
'sparsity_parameters': tensor_sparsity_params
}
return details

View File

@@ -190,6 +190,43 @@ class InterpreterTest(test_util.TensorFlowTestCase):
# Ensure that we retrieve per channel quantization params correctly.
self.assertEqual(len(qparams['scales']), 128)
def testDenseTensorAccess(self):
  """A dense tensor must report empty sparsity parameters."""
  interpreter = interpreter_wrapper.Interpreter(
      model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
  interpreter.allocate_tensors()
  # Tensor 1 holds the (dense) convolution weights in this model.
  weight_details = interpreter.get_tensor_details()[1]
  self.assertEqual(weight_details['sparsity_parameters'], {})
def testSparseTensorAccess(self):
  """Sparse tensor buffer and sparsity metadata are exposed correctly."""
  interpreter = interpreter_wrapper.InterpreterWithCustomOps(
      model_path=resource_loader.get_path_to_datafile(
          '../testdata/sparse_tensor.bin'),
      custom_op_registerers=['TF_TestRegisterer'])
  interpreter.allocate_tensors()

  # Tensor at index 0 is sparse; get_tensor returns its compressed buffer.
  compressed_buffer = interpreter.get_tensor(0)
  # Ensure that the buffer is of correct size and value.
  self.assertEqual(len(compressed_buffer), 12)
  self.assertAllEqual(compressed_buffer,
                      [1, 0, 0, 4, 2, 3, 0, 0, 5, 0, 0, 6])

  # Ensure the returned sparsity parameters are correct.
  s_params = interpreter.get_tensor_details()[0]['sparsity_parameters']
  self.assertAllEqual(s_params['traversal_order'], [0, 1, 2, 3])
  self.assertAllEqual(s_params['block_map'], [0, 1])

  # Dimensions 0, 2 and 3 are dense (format 0); dimension 1 is sparse
  # (format 1) with segment/index arrays.
  dense_dim_metadata = {'format': 0, 'dense_size': 2}
  for dense_dim in (0, 2, 3):
    self.assertAllEqual(s_params['dim_metadata'][dense_dim],
                        dense_dim_metadata)
  sparse_dim_metadata = s_params['dim_metadata'][1]
  self.assertEqual(sparse_dim_metadata['format'], 1)
  self.assertAllEqual(sparse_dim_metadata['array_segments'], [0, 2, 3])
  self.assertAllEqual(sparse_dim_metadata['array_indices'], [0, 1, 1])
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):

View File

@@ -29,6 +29,7 @@ cc_library(
":python_utils",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite:util",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/experimental/tflite_api_dispatcher",

View File

@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/util.h"
#define TFLITE_PY_CHECK(x) \
if ((x) != kTfLiteOk) { \
@@ -110,6 +111,38 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
return result;
}
// Builds a Python dict describing `param`:
//   {'traversal_order': ndarray, 'block_map': ndarray,
//    'dim_metadata': [per-dimension dicts]}
// where each dense dimension yields {'format': 0, 'dense_size': n} and each
// non-dense dimension yields {'format': 1, 'array_segments': ndarray,
// 'array_indices': ndarray}.
//
// Bug fix: PyDict_SetItemString does NOT steal references, so every freshly
// created value must be released after insertion or it leaks on each call.
// (PyList_SetItem, in contrast, does steal the reference.)
PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) {
  // Inserts `value` under `key`, then drops our ownership of `value`.
  auto set_item = [](PyObject* dict, const char* key, PyObject* value) {
    PyDict_SetItemString(dict, key, value);
    Py_XDECREF(value);
  };

  PyObject* result = PyDict_New();
  set_item(result, "traversal_order",
           PyArrayFromIntVector(param.traversal_order->data,
                                param.traversal_order->size));
  set_item(result, "block_map",
           PyArrayFromIntVector(param.block_map->data, param.block_map->size));

  PyObject* dim_metadata = PyList_New(param.dim_metadata_size);
  for (int i = 0; i < param.dim_metadata_size; i++) {
    PyObject* dim_metadata_i = PyDict_New();
    if (param.dim_metadata[i].format == kTfLiteDimDense) {
      // Dense dimension: only its size is recorded.
      set_item(dim_metadata_i, "format", PyLong_FromSize_t(0));
      set_item(dim_metadata_i, "dense_size",
               PyLong_FromSize_t(param.dim_metadata[i].dense_size));
    } else {
      // Sparse dimension: described by segment and index arrays.
      set_item(dim_metadata_i, "format", PyLong_FromSize_t(1));
      const auto* array_segments = param.dim_metadata[i].array_segments;
      const auto* array_indices = param.dim_metadata[i].array_indices;
      set_item(dim_metadata_i, "array_segments",
               PyArrayFromIntVector(array_segments->data,
                                    array_segments->size));
      set_item(dim_metadata_i, "array_indices",
               PyArrayFromIntVector(array_indices->data, array_indices->size));
    }
    // PyList_SetItem steals the reference to dim_metadata_i.
    PyList_SetItem(dim_metadata, i, dim_metadata_i);
  }
  set_item(result, "dim_metadata", dim_metadata);
  return result;
}
bool RegisterCustomOpByName(const char* registerer_name,
tflite::MutableOpResolver* resolver,
std::string* error_msg) {
@@ -322,6 +355,17 @@ PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
// Returns the sparsity parameters of tensor `i` as a Python dict.
// Dense tensors (no sparsity metadata) yield an empty dict.
PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (const TfLiteSparsity* sparsity = tensor->sparsity) {
    return PyDictFromSparsityParam(*sparsity);
  }
  return PyDict_New();
}
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
@@ -536,8 +580,21 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
return nullptr;
}
memcpy(data, tensor->data.raw, tensor->bytes);
PyObject* np_array =
PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
PyObject* np_array;
if (tensor->sparsity == nullptr) {
np_array =
PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
} else {
std::vector<npy_intp> sparse_buffer_dims(1);
size_t size_of_type;
if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) {
PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
return nullptr;
}
sparse_buffer_dims[0] = tensor->bytes / size_of_type;
np_array = PyArray_SimpleNewFromData(
sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num, data);
}
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
NPY_ARRAY_OWNDATA);
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));

View File

@@ -70,6 +70,7 @@ class InterpreterWrapper {
PyObject* TensorType(int i) const;
PyObject* TensorSize(int i) const;
PyObject* TensorSizeSignature(int i) const;
PyObject* TensorSparsityParameters(int i) const;
// Deprecated in favor of TensorQuantizationScales, below.
PyObject* TensorQuantization(int i) const;
PyObject* TensorQuantizationParameters(int i) const;

View File

@@ -18,12 +18,20 @@ namespace tflite {
namespace {
static int num_test_registerer_calls = 0;
// Returns a pointer to the single shared, zero-initialized registration
// object used as the implementation of the test-only "FakeOp" custom op.
TfLiteRegistration* GetFakeRegistration() {
  static TfLiteRegistration fake_op = {};
  return &fake_op;
}
} // namespace
// Dummy registerer function with the correct signature. Ignores the resolver
// but increments the num_test_registerer_calls counter by one. The TF_ prefix
// is needed to get past the version script in the OSS build.
// Dummy registerer function with the correct signature. Registers a fake custom
// op needed by test models. Increments the num_test_registerer_calls counter by
// one. The TF_ prefix is needed to get past the version script in the OSS
// build.
extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver) {
  // Count every invocation so tests can verify the registerer actually ran.
  num_test_registerer_calls++;
  // Make the "FakeOp" custom op resolvable by test models that use it.
  resolver->AddCustom("FakeOp", GetFakeRegistration());
}

View File

@@ -103,11 +103,13 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
*bytes = sizeof(TfLiteFloat16);
break;
default:
context->ReportError(
context,
"Type %d is unsupported. Only float32, int8, int16, int32, int64, "
"uint8, bool, complex64 supported currently.",
type);
if (context) {
context->ReportError(
context,
"Type %d is unsupported. Only float32, int8, int16, int32, int64, "
"uint8, bool, complex64 supported currently.",
type);
}
return kTfLiteError;
}
return kTfLiteOk;