Add strict argument to resize_tensor_input in TFLite Python API.

PiperOrigin-RevId: 309875051
Change-Id: Ia246d6a8b6c087ddcad86b3ead76cd2af833ec59
This commit is contained in:
Nupur Garg 2020-05-04 21:00:52 -07:00 committed by TensorFlower Gardener
parent 6059623a34
commit 89072571ca
8 changed files with 76 additions and 11 deletions

View File

@ -406,13 +406,23 @@ class Interpreter(object):
""" """
self._interpreter.SetTensor(tensor_index, value) self._interpreter.SetTensor(tensor_index, value)
def resize_tensor_input(self, input_index, tensor_size): def resize_tensor_input(self, input_index, tensor_size, strict=False):
"""Resizes an input tensor. """Resizes an input tensor.
```
interpreter = Interpreter(model_content=tflite_model)
interpreter.resize_tensor_input(0, [1, 224, 224, 3], strict=True)
interpreter.allocate_tensors()
interpreter.invoke()
```
Args: Args:
input_index: Tensor index of input to set. This value can be gotten from input_index: Tensor index of input to set. This value can be gotten from
the 'index' field in get_input_details. the 'index' field in get_input_details.
tensor_size: The tensor_shape to resize the input to. tensor_size: The tensor_shape to resize the input to.
strict: Only unknown dimensions can be resized when `strict` is True.
Unknown dimensions are indicated as `-1` in the `shape_signature`
attribute of a given tensor. (default False)
Raises: Raises:
ValueError: If the interpreter could not resize the input tensor. ValueError: If the interpreter could not resize the input tensor.
@ -421,7 +431,7 @@ class Interpreter(object):
# `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size
# parameter. # parameter.
tensor_size = np.array(tensor_size, dtype=np.int32) tensor_size = np.array(tensor_size, dtype=np.int32)
self._interpreter.ResizeInputTensor(input_index, tensor_size) self._interpreter.ResizeInputTensor(input_index, tensor_size, strict)
def get_output_details(self): def get_output_details(self):
"""Gets model output details. """Gets model output details.

View File

@ -257,7 +257,7 @@ PyObject* InterpreterWrapper::OutputIndices() const {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
} }
PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
TFLITE_PY_ENSURE_VALID_INTERPRETER(); TFLITE_PY_ENSURE_VALID_INTERPRETER();
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe( std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
@ -282,10 +282,27 @@ PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
return nullptr; return nullptr;
} }
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(array),
NPY_ARRAY_OWNDATA);
return PyArray_Return(reinterpret_cast<PyArrayObject*>(array));
}
PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
bool strict) {
PyArrayObject* array =
reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
if (array == nullptr) {
return nullptr;
}
std::vector<int> dims(PyArray_SHAPE(array)[0]); std::vector<int> dims(PyArray_SHAPE(array)[0]);
memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int)); memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims)); if (strict) {
TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(i, dims));
} else {
TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
}
Py_RETURN_NONE; Py_RETURN_NONE;
} }

View File

@ -63,7 +63,7 @@ class InterpreterWrapper {
PyObject* InputIndices() const; PyObject* InputIndices() const;
PyObject* OutputIndices() const; PyObject* OutputIndices() const;
PyObject* ResizeInputTensor(int i, PyObject* value); PyObject* ResizeInputTensor(int i, PyObject* value, bool strict);
int NumTensors() const; int NumTensors() const;
std::string TensorName(int i) const; std::string TensorName(int i) const;
@ -110,6 +110,9 @@ class InterpreterWrapper {
InterpreterWrapper(); InterpreterWrapper();
InterpreterWrapper(const InterpreterWrapper& rhs); InterpreterWrapper(const InterpreterWrapper& rhs);
// Helper function to resize an input tensor.
PyObject* ResizeInputTensorImpl(int i, PyObject* value);
// The public functions which creates `InterpreterWrapper` should ensure all // The public functions which creates `InterpreterWrapper` should ensure all
// these member variables are initialized successfully. Otherwise it should // these member variables are initialized successfully. Otherwise it should
// report the error and return `nullptr`. // report the error and return `nullptr`.

View File

@ -70,9 +70,9 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
return tensorflow::PyoOrThrow(self.OutputIndices()); return tensorflow::PyoOrThrow(self.OutputIndices());
}) })
.def("ResizeInputTensor", .def("ResizeInputTensor",
[](InterpreterWrapper& self, int i, py::handle& value) { [](InterpreterWrapper& self, int i, py::handle& value, bool strict) {
return tensorflow::PyoOrThrow( return tensorflow::PyoOrThrow(
self.ResizeInputTensor(i, value.ptr())); self.ResizeInputTensor(i, value.ptr(), strict));
}) })
.def("NumTensors", &InterpreterWrapper::NumTensors) .def("NumTensors", &InterpreterWrapper::NumTensors)
.def("TensorName", &InterpreterWrapper::TensorName) .def("TensorName", &InterpreterWrapper::TensorName)

View File

@ -450,8 +450,15 @@ class FromSessionTest(TestModels, parameterized.TestCase):
3] == input_details[0]['shape_signature']).all()) 3] == input_details[0]['shape_signature']).all())
self.assertEqual((0., 0.), input_details[0]['quantization']) self.assertEqual((0., 0.), input_details[0]['quantization'])
# Resize tensor with strict checking.
with self.assertRaises(RuntimeError) as error:
interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
self.assertIn(
'ResizeInputTensorStrict only allows mutating unknown dimensions '
'identified by -1.', str(error.exception))
# Resize tensor and invoke. # Resize tensor and invoke.
interpreter.resize_tensor_input(0, [1, 16, 16, 3]) interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
interpreter.allocate_tensors() interpreter.allocate_tensors()
interpreter.invoke() interpreter.invoke()
@ -465,6 +472,34 @@ class FromSessionTest(TestModels, parameterized.TestCase):
self.assertTrue(([1, -1, 16, self.assertTrue(([1, -1, 16,
3] == output_details[0]['shape_signature']).all()) 3] == output_details[0]['shape_signature']).all())
def testResizeTensorInputStrict(self):
# Ensures that resize_tensor_input(strict=True) works as expected.
with ops.Graph().as_default():
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
# Resize incorrect value.
with self.assertRaises(RuntimeError) as error:
interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
self.assertIn(
'ResizeInputTensorStrict only allows mutating unknown dimensions '
'identified by -1.', str(error.exception))
# Resize correct value.
interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
interpreter.allocate_tensors()
def testBatchSizeValid(self): def testBatchSizeValid(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
in_tensor = array_ops.placeholder( in_tensor = array_ops.placeholder(

View File

@ -53,7 +53,7 @@ class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
for idx, (shape_signature, final_shape) in enumerate(input_shapes): for idx, (shape_signature, final_shape) in enumerate(input_shapes):
self.assertTrue( self.assertTrue(
(input_details[idx]['shape_signature'] == shape_signature).all()) (input_details[idx]['shape_signature'] == shape_signature).all())
interpreter.resize_tensor_input(idx, final_shape) interpreter.resize_tensor_input(idx, final_shape, strict=True)
interpreter.allocate_tensors() interpreter.allocate_tensors()
output_details = interpreter.get_output_details() output_details = interpreter.get_output_details()

View File

@ -36,7 +36,7 @@ tf_class {
} }
member_method { member_method {
name: "resize_tensor_input" name: "resize_tensor_input"
argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'input_index\', \'tensor_size\', \'strict\'], varargs=None, keywords=None, defaults=[\'False\'], "
} }
member_method { member_method {
name: "set_tensor" name: "set_tensor"

View File

@ -36,7 +36,7 @@ tf_class {
} }
member_method { member_method {
name: "resize_tensor_input" name: "resize_tensor_input"
argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'input_index\', \'tensor_size\', \'strict\'], varargs=None, keywords=None, defaults=[\'False\'], "
} }
member_method { member_method {
name: "set_tensor" name: "set_tensor"