Add strict argument to resize_tensor_input in TFLite Python API.
PiperOrigin-RevId: 309875051 Change-Id: Ia246d6a8b6c087ddcad86b3ead76cd2af833ec59
This commit is contained in:
parent
6059623a34
commit
89072571ca
@ -406,13 +406,23 @@ class Interpreter(object):
|
|||||||
"""
|
"""
|
||||||
self._interpreter.SetTensor(tensor_index, value)
|
self._interpreter.SetTensor(tensor_index, value)
|
||||||
|
|
||||||
def resize_tensor_input(self, input_index, tensor_size):
|
def resize_tensor_input(self, input_index, tensor_size, strict=False):
|
||||||
"""Resizes an input tensor.
|
"""Resizes an input tensor.
|
||||||
|
|
||||||
|
```
|
||||||
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
|
interpreter.resize_tensor_input(0, [1, 224, 224, 3], strict=True)
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
interpreter.invoke()
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_index: Tensor index of input to set. This value can be gotten from
|
input_index: Tensor index of input to set. This value can be gotten from
|
||||||
the 'index' field in get_input_details.
|
the 'index' field in get_input_details.
|
||||||
tensor_size: The tensor_shape to resize the input to.
|
tensor_size: The tensor_shape to resize the input to.
|
||||||
|
strict: Only unknown dimensions can be resized when `strict` is True.
|
||||||
|
Unknown dimensions are indicated as `-1` in the `shape_signature`
|
||||||
|
attribute of a given tensor. (default False)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the interpreter could not resize the input tensor.
|
ValueError: If the interpreter could not resize the input tensor.
|
||||||
@ -421,7 +431,7 @@ class Interpreter(object):
|
|||||||
# `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size
|
# `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size
|
||||||
# parameter.
|
# parameter.
|
||||||
tensor_size = np.array(tensor_size, dtype=np.int32)
|
tensor_size = np.array(tensor_size, dtype=np.int32)
|
||||||
self._interpreter.ResizeInputTensor(input_index, tensor_size)
|
self._interpreter.ResizeInputTensor(input_index, tensor_size, strict)
|
||||||
|
|
||||||
def get_output_details(self):
|
def get_output_details(self):
|
||||||
"""Gets model output details.
|
"""Gets model output details.
|
||||||
|
|||||||
@ -257,7 +257,7 @@ PyObject* InterpreterWrapper::OutputIndices() const {
|
|||||||
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
|
PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
|
||||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||||
|
|
||||||
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
|
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
|
||||||
@ -282,10 +282,27 @@ PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(array),
|
||||||
|
NPY_ARRAY_OWNDATA);
|
||||||
|
return PyArray_Return(reinterpret_cast<PyArrayObject*>(array));
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
|
||||||
|
bool strict) {
|
||||||
|
PyArrayObject* array =
|
||||||
|
reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
|
||||||
|
if (array == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int> dims(PyArray_SHAPE(array)[0]);
|
std::vector<int> dims(PyArray_SHAPE(array)[0]);
|
||||||
memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
|
memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
|
||||||
|
|
||||||
TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
|
if (strict) {
|
||||||
|
TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(i, dims));
|
||||||
|
} else {
|
||||||
|
TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
|
||||||
|
}
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -63,7 +63,7 @@ class InterpreterWrapper {
|
|||||||
|
|
||||||
PyObject* InputIndices() const;
|
PyObject* InputIndices() const;
|
||||||
PyObject* OutputIndices() const;
|
PyObject* OutputIndices() const;
|
||||||
PyObject* ResizeInputTensor(int i, PyObject* value);
|
PyObject* ResizeInputTensor(int i, PyObject* value, bool strict);
|
||||||
|
|
||||||
int NumTensors() const;
|
int NumTensors() const;
|
||||||
std::string TensorName(int i) const;
|
std::string TensorName(int i) const;
|
||||||
@ -110,6 +110,9 @@ class InterpreterWrapper {
|
|||||||
InterpreterWrapper();
|
InterpreterWrapper();
|
||||||
InterpreterWrapper(const InterpreterWrapper& rhs);
|
InterpreterWrapper(const InterpreterWrapper& rhs);
|
||||||
|
|
||||||
|
// Helper function to resize an input tensor.
|
||||||
|
PyObject* ResizeInputTensorImpl(int i, PyObject* value);
|
||||||
|
|
||||||
// The public functions which creates `InterpreterWrapper` should ensure all
|
// The public functions which creates `InterpreterWrapper` should ensure all
|
||||||
// these member variables are initialized successfully. Otherwise it should
|
// these member variables are initialized successfully. Otherwise it should
|
||||||
// report the error and return `nullptr`.
|
// report the error and return `nullptr`.
|
||||||
|
|||||||
@ -70,9 +70,9 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
|
|||||||
return tensorflow::PyoOrThrow(self.OutputIndices());
|
return tensorflow::PyoOrThrow(self.OutputIndices());
|
||||||
})
|
})
|
||||||
.def("ResizeInputTensor",
|
.def("ResizeInputTensor",
|
||||||
[](InterpreterWrapper& self, int i, py::handle& value) {
|
[](InterpreterWrapper& self, int i, py::handle& value, bool strict) {
|
||||||
return tensorflow::PyoOrThrow(
|
return tensorflow::PyoOrThrow(
|
||||||
self.ResizeInputTensor(i, value.ptr()));
|
self.ResizeInputTensor(i, value.ptr(), strict));
|
||||||
})
|
})
|
||||||
.def("NumTensors", &InterpreterWrapper::NumTensors)
|
.def("NumTensors", &InterpreterWrapper::NumTensors)
|
||||||
.def("TensorName", &InterpreterWrapper::TensorName)
|
.def("TensorName", &InterpreterWrapper::TensorName)
|
||||||
|
|||||||
@ -450,8 +450,15 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||||||
3] == input_details[0]['shape_signature']).all())
|
3] == input_details[0]['shape_signature']).all())
|
||||||
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||||
|
|
||||||
|
# Resize tensor with strict checking.
|
||||||
|
with self.assertRaises(RuntimeError) as error:
|
||||||
|
interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
|
||||||
|
self.assertIn(
|
||||||
|
'ResizeInputTensorStrict only allows mutating unknown dimensions '
|
||||||
|
'identified by -1.', str(error.exception))
|
||||||
|
|
||||||
# Resize tensor and invoke.
|
# Resize tensor and invoke.
|
||||||
interpreter.resize_tensor_input(0, [1, 16, 16, 3])
|
interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
interpreter.invoke()
|
interpreter.invoke()
|
||||||
|
|
||||||
@ -465,6 +472,34 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||||||
self.assertTrue(([1, -1, 16,
|
self.assertTrue(([1, -1, 16,
|
||||||
3] == output_details[0]['shape_signature']).all())
|
3] == output_details[0]['shape_signature']).all())
|
||||||
|
|
||||||
|
def testResizeTensorInputStrict(self):
|
||||||
|
# Ensures that resize_tensor_input(strict=True) works as expected.
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
in_tensor = array_ops.placeholder(
|
||||||
|
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
||||||
|
out_tensor = in_tensor + in_tensor
|
||||||
|
sess = session.Session()
|
||||||
|
|
||||||
|
# Convert model and ensure model is not None.
|
||||||
|
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||||
|
[out_tensor])
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
self.assertTrue(tflite_model)
|
||||||
|
|
||||||
|
# Check values from converted model.
|
||||||
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
|
|
||||||
|
# Resize incorrect value.
|
||||||
|
with self.assertRaises(RuntimeError) as error:
|
||||||
|
interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
|
||||||
|
self.assertIn(
|
||||||
|
'ResizeInputTensorStrict only allows mutating unknown dimensions '
|
||||||
|
'identified by -1.', str(error.exception))
|
||||||
|
|
||||||
|
# Resize correct value.
|
||||||
|
interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
def testBatchSizeValid(self):
|
def testBatchSizeValid(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
in_tensor = array_ops.placeholder(
|
in_tensor = array_ops.placeholder(
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
for idx, (shape_signature, final_shape) in enumerate(input_shapes):
|
for idx, (shape_signature, final_shape) in enumerate(input_shapes):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(input_details[idx]['shape_signature'] == shape_signature).all())
|
(input_details[idx]['shape_signature'] == shape_signature).all())
|
||||||
interpreter.resize_tensor_input(idx, final_shape)
|
interpreter.resize_tensor_input(idx, final_shape, strict=True)
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
output_details = interpreter.get_output_details()
|
output_details = interpreter.get_output_details()
|
||||||
|
|||||||
@ -36,7 +36,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "resize_tensor_input"
|
name: "resize_tensor_input"
|
||||||
argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'input_index\', \'tensor_size\', \'strict\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "set_tensor"
|
name: "set_tensor"
|
||||||
|
|||||||
@ -36,7 +36,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "resize_tensor_input"
|
name: "resize_tensor_input"
|
||||||
argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'input_index\', \'tensor_size\', \'strict\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "set_tensor"
|
name: "set_tensor"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user