Add strict argument to resize_tensor_input in TFLite Python API.
PiperOrigin-RevId: 309875051 Change-Id: Ia246d6a8b6c087ddcad86b3ead76cd2af833ec59
This commit is contained in:
parent
6059623a34
commit
89072571ca
@ -406,13 +406,23 @@ class Interpreter(object):
|
||||
"""
|
||||
self._interpreter.SetTensor(tensor_index, value)
|
||||
|
||||
def resize_tensor_input(self, input_index, tensor_size):
|
||||
def resize_tensor_input(self, input_index, tensor_size, strict=False):
|
||||
"""Resizes an input tensor.
|
||||
|
||||
```
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.resize_tensor_input(0, [1, 224, 224, 3], strict=True)
|
||||
interpreter.allocate_tensors()
|
||||
interpreter.invoke()
|
||||
```
|
||||
|
||||
Args:
|
||||
input_index: Tensor index of input to set. This value can be gotten from
|
||||
the 'index' field in get_input_details.
|
||||
tensor_size: The tensor_shape to resize the input to.
|
||||
strict: Only unknown dimensions can be resized when `strict` is True.
|
||||
Unknown dimensions are indicated as `-1` in the `shape_signature`
|
||||
attribute of a given tensor. (default False)
|
||||
|
||||
Raises:
|
||||
ValueError: If the interpreter could not resize the input tensor.
|
||||
@ -421,7 +431,7 @@ class Interpreter(object):
|
||||
# `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size
|
||||
# parameter.
|
||||
tensor_size = np.array(tensor_size, dtype=np.int32)
|
||||
self._interpreter.ResizeInputTensor(input_index, tensor_size)
|
||||
self._interpreter.ResizeInputTensor(input_index, tensor_size, strict)
|
||||
|
||||
def get_output_details(self):
|
||||
"""Gets model output details.
|
||||
|
||||
@ -257,7 +257,7 @@ PyObject* InterpreterWrapper::OutputIndices() const {
|
||||
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
|
||||
PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
|
||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||
|
||||
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
|
||||
@ -282,10 +282,27 @@ PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(array),
|
||||
NPY_ARRAY_OWNDATA);
|
||||
return PyArray_Return(reinterpret_cast<PyArrayObject*>(array));
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
|
||||
bool strict) {
|
||||
PyArrayObject* array =
|
||||
reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
|
||||
if (array == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> dims(PyArray_SHAPE(array)[0]);
|
||||
memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
|
||||
|
||||
TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
|
||||
if (strict) {
|
||||
TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(i, dims));
|
||||
} else {
|
||||
TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ class InterpreterWrapper {
|
||||
|
||||
PyObject* InputIndices() const;
|
||||
PyObject* OutputIndices() const;
|
||||
PyObject* ResizeInputTensor(int i, PyObject* value);
|
||||
PyObject* ResizeInputTensor(int i, PyObject* value, bool strict);
|
||||
|
||||
int NumTensors() const;
|
||||
std::string TensorName(int i) const;
|
||||
@ -110,6 +110,9 @@ class InterpreterWrapper {
|
||||
InterpreterWrapper();
|
||||
InterpreterWrapper(const InterpreterWrapper& rhs);
|
||||
|
||||
// Helper function to resize an input tensor.
|
||||
PyObject* ResizeInputTensorImpl(int i, PyObject* value);
|
||||
|
||||
// The public functions which creates `InterpreterWrapper` should ensure all
|
||||
// these member variables are initialized successfully. Otherwise it should
|
||||
// report the error and return `nullptr`.
|
||||
|
||||
@ -70,9 +70,9 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
|
||||
return tensorflow::PyoOrThrow(self.OutputIndices());
|
||||
})
|
||||
.def("ResizeInputTensor",
|
||||
[](InterpreterWrapper& self, int i, py::handle& value) {
|
||||
[](InterpreterWrapper& self, int i, py::handle& value, bool strict) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
self.ResizeInputTensor(i, value.ptr()));
|
||||
self.ResizeInputTensor(i, value.ptr(), strict));
|
||||
})
|
||||
.def("NumTensors", &InterpreterWrapper::NumTensors)
|
||||
.def("TensorName", &InterpreterWrapper::TensorName)
|
||||
|
||||
@ -450,8 +450,15 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||
3] == input_details[0]['shape_signature']).all())
|
||||
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||
|
||||
# Resize tensor with strict checking.
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
|
||||
self.assertIn(
|
||||
'ResizeInputTensorStrict only allows mutating unknown dimensions '
|
||||
'identified by -1.', str(error.exception))
|
||||
|
||||
# Resize tensor and invoke.
|
||||
interpreter.resize_tensor_input(0, [1, 16, 16, 3])
|
||||
interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
|
||||
interpreter.allocate_tensors()
|
||||
interpreter.invoke()
|
||||
|
||||
@ -465,6 +472,34 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||
self.assertTrue(([1, -1, 16,
|
||||
3] == output_details[0]['shape_signature']).all())
|
||||
|
||||
def testResizeTensorInputStrict(self):
|
||||
# Ensures that resize_tensor_input(strict=True) works as expected.
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
||||
out_tensor = in_tensor + in_tensor
|
||||
sess = session.Session()
|
||||
|
||||
# Convert model and ensure model is not None.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
|
||||
# Resize incorrect value.
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
|
||||
self.assertIn(
|
||||
'ResizeInputTensorStrict only allows mutating unknown dimensions '
|
||||
'identified by -1.', str(error.exception))
|
||||
|
||||
# Resize correct value.
|
||||
interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
def testBatchSizeValid(self):
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(
|
||||
|
||||
@ -53,7 +53,7 @@ class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
for idx, (shape_signature, final_shape) in enumerate(input_shapes):
|
||||
self.assertTrue(
|
||||
(input_details[idx]['shape_signature'] == shape_signature).all())
|
||||
interpreter.resize_tensor_input(idx, final_shape)
|
||||
interpreter.resize_tensor_input(idx, final_shape, strict=True)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
|
||||
@ -36,7 +36,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "resize_tensor_input"
|
||||
argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'input_index\', \'tensor_size\', \'strict\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_tensor"
|
||||
|
||||
@ -36,7 +36,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "resize_tensor_input"
|
||||
argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'input_index\', \'tensor_size\', \'strict\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_tensor"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user