Make interpreter python wrapper handle sparse tensor properly.
PiperOrigin-RevId: 295244210 Change-Id: I75c76ac4974051325e285191464613500738951b
This commit is contained in:
parent
d6741e0994
commit
4911ca4a01
@ -29,6 +29,7 @@ py_test(
|
|||||||
name = "interpreter_test",
|
name = "interpreter_test",
|
||||||
srcs = ["interpreter_test.py"],
|
srcs = ["interpreter_test.py"],
|
||||||
data = [
|
data = [
|
||||||
|
"//tensorflow/lite:testdata/sparse_tensor.bin",
|
||||||
"//tensorflow/lite/python/testdata:interpreter_test_data",
|
"//tensorflow/lite/python/testdata:interpreter_test_data",
|
||||||
"//tensorflow/lite/python/testdata:test_delegate.so",
|
"//tensorflow/lite/python/testdata:test_delegate.so",
|
||||||
],
|
],
|
||||||
|
@ -325,6 +325,8 @@ class Interpreter(object):
|
|||||||
tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
|
tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
|
||||||
tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
|
tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
|
||||||
tensor_index)
|
tensor_index)
|
||||||
|
tensor_sparsity_params = self._interpreter.TensorSparsityParameters(
|
||||||
|
tensor_index)
|
||||||
|
|
||||||
if not tensor_name or not tensor_type:
|
if not tensor_name or not tensor_type:
|
||||||
raise ValueError('Could not get tensor details')
|
raise ValueError('Could not get tensor details')
|
||||||
@ -340,7 +342,8 @@ class Interpreter(object):
|
|||||||
'scales': tensor_quantization_params[0],
|
'scales': tensor_quantization_params[0],
|
||||||
'zero_points': tensor_quantization_params[1],
|
'zero_points': tensor_quantization_params[1],
|
||||||
'quantized_dimension': tensor_quantization_params[2],
|
'quantized_dimension': tensor_quantization_params[2],
|
||||||
}
|
},
|
||||||
|
'sparsity_parameters': tensor_sparsity_params
|
||||||
}
|
}
|
||||||
|
|
||||||
return details
|
return details
|
||||||
|
@ -190,6 +190,43 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
|||||||
# Ensure that we retrieve per channel quantization params correctly.
|
# Ensure that we retrieve per channel quantization params correctly.
|
||||||
self.assertEqual(len(qparams['scales']), 128)
|
self.assertEqual(len(qparams['scales']), 128)
|
||||||
|
|
||||||
|
def testDenseTensorAccess(self):
|
||||||
|
interpreter = interpreter_wrapper.Interpreter(
|
||||||
|
model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
weight_details = interpreter.get_tensor_details()[1]
|
||||||
|
s_params = weight_details['sparsity_parameters']
|
||||||
|
self.assertEqual(s_params, {})
|
||||||
|
|
||||||
|
def testSparseTensorAccess(self):
|
||||||
|
interpreter = interpreter_wrapper.InterpreterWithCustomOps(
|
||||||
|
model_path=resource_loader.get_path_to_datafile(
|
||||||
|
'../testdata/sparse_tensor.bin'),
|
||||||
|
custom_op_registerers=['TF_TestRegisterer'])
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
# Tensor at index 0 is sparse.
|
||||||
|
compressed_buffer = interpreter.get_tensor(0)
|
||||||
|
# Ensure that the buffer is of correct size and value.
|
||||||
|
self.assertEqual(len(compressed_buffer), 12)
|
||||||
|
sparse_value = [1, 0, 0, 4, 2, 3, 0, 0, 5, 0, 0, 6]
|
||||||
|
self.assertAllEqual(compressed_buffer, sparse_value)
|
||||||
|
|
||||||
|
tensor_details = interpreter.get_tensor_details()[0]
|
||||||
|
s_params = tensor_details['sparsity_parameters']
|
||||||
|
|
||||||
|
# Ensure sparsity parameter returned is correct
|
||||||
|
self.assertAllEqual(s_params['traversal_order'], [0, 1, 2, 3])
|
||||||
|
self.assertAllEqual(s_params['block_map'], [0, 1])
|
||||||
|
dense_dim_metadata = {'format': 0, 'dense_size': 2}
|
||||||
|
self.assertAllEqual(s_params['dim_metadata'][0], dense_dim_metadata)
|
||||||
|
self.assertAllEqual(s_params['dim_metadata'][2], dense_dim_metadata)
|
||||||
|
self.assertAllEqual(s_params['dim_metadata'][3], dense_dim_metadata)
|
||||||
|
self.assertEqual(s_params['dim_metadata'][1]['format'], 1)
|
||||||
|
self.assertAllEqual(s_params['dim_metadata'][1]['array_segments'],
|
||||||
|
[0, 2, 3])
|
||||||
|
self.assertAllEqual(s_params['dim_metadata'][1]['array_indices'], [0, 1, 1])
|
||||||
|
|
||||||
|
|
||||||
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@ -29,6 +29,7 @@ cc_library(
|
|||||||
":python_utils",
|
":python_utils",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite:string_util",
|
"//tensorflow/lite:string_util",
|
||||||
|
"//tensorflow/lite:util",
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/core/api",
|
"//tensorflow/lite/core/api",
|
||||||
"//tensorflow/lite/experimental/tflite_api_dispatcher",
|
"//tensorflow/lite/experimental/tflite_api_dispatcher",
|
||||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
#include "tensorflow/lite/util.h"
|
||||||
|
|
||||||
#define TFLITE_PY_CHECK(x) \
|
#define TFLITE_PY_CHECK(x) \
|
||||||
if ((x) != kTfLiteOk) { \
|
if ((x) != kTfLiteOk) { \
|
||||||
@ -110,6 +111,38 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) {
|
||||||
|
PyObject* result = PyDict_New();
|
||||||
|
PyDict_SetItemString(result, "traversal_order",
|
||||||
|
PyArrayFromIntVector(param.traversal_order->data,
|
||||||
|
param.traversal_order->size));
|
||||||
|
PyDict_SetItemString(
|
||||||
|
result, "block_map",
|
||||||
|
PyArrayFromIntVector(param.block_map->data, param.block_map->size));
|
||||||
|
PyObject* dim_metadata = PyList_New(param.dim_metadata_size);
|
||||||
|
for (int i = 0; i < param.dim_metadata_size; i++) {
|
||||||
|
PyObject* dim_metadata_i = PyDict_New();
|
||||||
|
if (param.dim_metadata[i].format == kTfLiteDimDense) {
|
||||||
|
PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(0));
|
||||||
|
PyDict_SetItemString(dim_metadata_i, "dense_size",
|
||||||
|
PyLong_FromSize_t(param.dim_metadata[i].dense_size));
|
||||||
|
} else {
|
||||||
|
PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(1));
|
||||||
|
const auto* array_segments = param.dim_metadata[i].array_segments;
|
||||||
|
const auto* array_indices = param.dim_metadata[i].array_indices;
|
||||||
|
PyDict_SetItemString(
|
||||||
|
dim_metadata_i, "array_segments",
|
||||||
|
PyArrayFromIntVector(array_segments->data, array_segments->size));
|
||||||
|
PyDict_SetItemString(
|
||||||
|
dim_metadata_i, "array_indices",
|
||||||
|
PyArrayFromIntVector(array_indices->data, array_indices->size));
|
||||||
|
}
|
||||||
|
PyList_SetItem(dim_metadata, i, dim_metadata_i);
|
||||||
|
}
|
||||||
|
PyDict_SetItemString(result, "dim_metadata", dim_metadata);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
bool RegisterCustomOpByName(const char* registerer_name,
|
bool RegisterCustomOpByName(const char* registerer_name,
|
||||||
tflite::MutableOpResolver* resolver,
|
tflite::MutableOpResolver* resolver,
|
||||||
std::string* error_msg) {
|
std::string* error_msg) {
|
||||||
@ -322,6 +355,17 @@ PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
|
|||||||
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const {
|
||||||
|
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||||
|
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
|
||||||
|
const TfLiteTensor* tensor = interpreter_->tensor(i);
|
||||||
|
if (tensor->sparsity == nullptr) {
|
||||||
|
return PyDict_New();
|
||||||
|
}
|
||||||
|
|
||||||
|
return PyDictFromSparsityParam(*tensor->sparsity);
|
||||||
|
}
|
||||||
|
|
||||||
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
|
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
|
||||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||||
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
|
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
|
||||||
@ -536,8 +580,21 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
memcpy(data, tensor->data.raw, tensor->bytes);
|
memcpy(data, tensor->data.raw, tensor->bytes);
|
||||||
PyObject* np_array =
|
PyObject* np_array;
|
||||||
PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
|
if (tensor->sparsity == nullptr) {
|
||||||
|
np_array =
|
||||||
|
PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
|
||||||
|
} else {
|
||||||
|
std::vector<npy_intp> sparse_buffer_dims(1);
|
||||||
|
size_t size_of_type;
|
||||||
|
if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) {
|
||||||
|
PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
sparse_buffer_dims[0] = tensor->bytes / size_of_type;
|
||||||
|
np_array = PyArray_SimpleNewFromData(
|
||||||
|
sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num, data);
|
||||||
|
}
|
||||||
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
|
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
|
||||||
NPY_ARRAY_OWNDATA);
|
NPY_ARRAY_OWNDATA);
|
||||||
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
||||||
|
@ -70,6 +70,7 @@ class InterpreterWrapper {
|
|||||||
PyObject* TensorType(int i) const;
|
PyObject* TensorType(int i) const;
|
||||||
PyObject* TensorSize(int i) const;
|
PyObject* TensorSize(int i) const;
|
||||||
PyObject* TensorSizeSignature(int i) const;
|
PyObject* TensorSizeSignature(int i) const;
|
||||||
|
PyObject* TensorSparsityParameters(int i) const;
|
||||||
// Deprecated in favor of TensorQuantizationScales, below.
|
// Deprecated in favor of TensorQuantizationScales, below.
|
||||||
PyObject* TensorQuantization(int i) const;
|
PyObject* TensorQuantization(int i) const;
|
||||||
PyObject* TensorQuantizationParameters(int i) const;
|
PyObject* TensorQuantizationParameters(int i) const;
|
||||||
|
@ -18,12 +18,20 @@ namespace tflite {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
static int num_test_registerer_calls = 0;
|
static int num_test_registerer_calls = 0;
|
||||||
|
|
||||||
|
TfLiteRegistration* GetFakeRegistration() {
|
||||||
|
static TfLiteRegistration fake_op;
|
||||||
|
return &fake_op;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Dummy registerer function with the correct signature. Ignores the resolver
|
// Dummy registerer function with the correct signature. Registers a fake custom
|
||||||
// but increments the num_test_registerer_calls counter by one. The TF_ prefix
|
// op needed by test models. Increments the num_test_registerer_calls counter by
|
||||||
// is needed to get past the version script in the OSS build.
|
// one. The TF_ prefix is needed to get past the version script in the OSS
|
||||||
|
// build.
|
||||||
extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver) {
|
extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver) {
|
||||||
|
resolver->AddCustom("FakeOp", GetFakeRegistration());
|
||||||
num_test_registerer_calls++;
|
num_test_registerer_calls++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,11 +103,13 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
|
|||||||
*bytes = sizeof(TfLiteFloat16);
|
*bytes = sizeof(TfLiteFloat16);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
context->ReportError(
|
if (context) {
|
||||||
context,
|
context->ReportError(
|
||||||
"Type %d is unsupported. Only float32, int8, int16, int32, int64, "
|
context,
|
||||||
"uint8, bool, complex64 supported currently.",
|
"Type %d is unsupported. Only float32, int8, int16, int32, int64, "
|
||||||
type);
|
"uint8, bool, complex64 supported currently.",
|
||||||
|
type);
|
||||||
|
}
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user