Introduce new Python API for accessing SignatureDefs in a TFLite model.
To run a specific SignatureDef use get_signature_runner(..) to get a SignatureRunner for running inference. The SignatureRunner returned is a callable object and can be called to invoke inference. Example, my_signature = interpreter.get_signature_runner("my_method") results = my_signature(input_1=input_tensor_1, input_2=input_tensor_2) print(results["my_output"]) To get the details about the available Signatures use interpreter.get_signature_list() Example, signatures = interpreter.get_signature_list() print(signatures) PiperOrigin-RevId: 345583496 Change-Id: Ie2a0a7e4e5676f06e98c82247cf4327534ce308e
This commit is contained in:
parent
293bd7502b
commit
95fe4c80e7
@ -53,7 +53,10 @@
|
||||
* Added dynamic range quantization support for the BatchMatMul op.
|
||||
* Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
|
||||
only supports float32 input.
|
||||
|
||||
* TFLite Supports SingatureDef:
|
||||
* TFLiteConverter exports models with SignatureDef
|
||||
* Interpreter supports getting a list of signatures and getting callable
|
||||
function for a given signaturedef.
|
||||
* TF Core:
|
||||
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
|
||||
`tf.while_loop`, and compositions like `tf.foldl`) computed with
|
||||
|
@ -156,6 +156,81 @@ def load_delegate(library, options=None):
|
||||
return delegate
|
||||
|
||||
|
||||
class SignatureRunner(object):
|
||||
"""SignatureRunner class for running TFLite models using SignatureDef.
|
||||
|
||||
This class should be instantiated through TFLite Interpreter only using
|
||||
get_signature_runner method on Interpreter.
|
||||
Example,
|
||||
signature = interpreter.get_signature_runner("my_signature")
|
||||
result = signature(input_1=my_input_1, input_2=my_input_2)
|
||||
print(result["my_output"])
|
||||
print(result["my_second_output"])
|
||||
All names used are this specific SignatureDef names.
|
||||
|
||||
Notes:
|
||||
No other function on this object or on the interpreter provided should be
|
||||
called while this object call has not finished.
|
||||
"""
|
||||
|
||||
def __init__(self, interpreter=None, signature_def_name=None):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
interpreter: Interpreter object that is already initialized with the
|
||||
requested model.
|
||||
signature_def_name: SignatureDef names to be used.
|
||||
"""
|
||||
if not interpreter:
|
||||
raise ValueError('None interpreter provided.')
|
||||
if not signature_def_name:
|
||||
raise ValueError('None signature_def_name provided.')
|
||||
self._interpreter = interpreter
|
||||
self._signature_def_name = signature_def_name
|
||||
signature_defs = interpreter._get_full_signature_list()
|
||||
if signature_def_name not in signature_defs:
|
||||
raise ValueError('Invalid signature_def_name provided.')
|
||||
self._signature_def = signature_defs[signature_def_name]
|
||||
self._outputs = self._signature_def['outputs'].items()
|
||||
self._inputs = self._signature_def['inputs']
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
"""Runs the SignatureDef given the provided inputs in arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: key,value for inputs to the model. Key is the SignatureDef input
|
||||
name. Value is numpy array with the value.
|
||||
|
||||
Returns:
|
||||
dictionary of the results from the model invoke.
|
||||
Key in the dictionary is SignatureDef output name.
|
||||
Value is the result Tensor.
|
||||
"""
|
||||
|
||||
if len(kwargs) != len(self._inputs):
|
||||
raise ValueError(
|
||||
'Invalid number of inputs provided for running a SignatureDef, '
|
||||
'expected %s vs provided %s' % (len(kwargs), len(self._inputs)))
|
||||
# Resize input tensors
|
||||
for input_name, value in kwargs.items():
|
||||
if input_name not in self._inputs:
|
||||
raise ValueError('Invalid Input name (%s) for SignatureDef' %
|
||||
input_name)
|
||||
self._interpreter.resize_tensor_input(self._inputs[input_name],
|
||||
value.shape)
|
||||
# Allocate tensors.
|
||||
self._interpreter.allocate_tensors()
|
||||
# Set the input values.
|
||||
for input_name, value in kwargs.items():
|
||||
self._interpreter._set_input_tensor(
|
||||
input_name, value=value, method_name=self._signature_def_name)
|
||||
self._interpreter.invoke()
|
||||
result = {}
|
||||
for output_name, output_index in self._outputs:
|
||||
result[output_name] = self._interpreter.get_tensor(output_index)
|
||||
return result
|
||||
|
||||
|
||||
@_tf_export('lite.Interpreter')
|
||||
class Interpreter(object):
|
||||
"""Interpreter interface for TensorFlow Lite Models.
|
||||
@ -244,6 +319,7 @@ class Interpreter(object):
|
||||
for delegate in self._delegates:
|
||||
self._interpreter.ModifyGraphWithDelegate(
|
||||
delegate._get_native_delegate_pointer()) # pylint: disable=protected-access
|
||||
self._signature_defs = self.get_signature_list()
|
||||
|
||||
def __del__(self):
|
||||
# Must make sure the interpreter is destroyed before things that
|
||||
@ -461,6 +537,148 @@ class Interpreter(object):
|
||||
self._get_tensor_details(i) for i in self._interpreter.OutputIndices()
|
||||
]
|
||||
|
||||
def get_signature_list(self):
|
||||
"""Gets list of SignatureDefs in the model.
|
||||
|
||||
Example,
|
||||
```
|
||||
signatures = interpreter.get_signature_list()
|
||||
print(signatures)
|
||||
|
||||
# {
|
||||
# 'add': {'inputs': ['x', 'y'], 'outputs': ['output_0']}
|
||||
# }
|
||||
|
||||
Then using the names in the signature list you can get a callable from
|
||||
get_signature_runner().
|
||||
```
|
||||
|
||||
Returns:
|
||||
A list of SignatureDef details in a dictionary structure.
|
||||
It is keyed on the SignatureDef method name, and the value holds
|
||||
dictionary of inputs and outputs.
|
||||
"""
|
||||
full_signature_defs = self._interpreter.GetSignatureDefs()
|
||||
for _, signature_def in full_signature_defs.items():
|
||||
signature_def['inputs'] = list(signature_def['inputs'].keys())
|
||||
signature_def['outputs'] = list(signature_def['outputs'].keys())
|
||||
return full_signature_defs
|
||||
|
||||
def _get_full_signature_list(self):
|
||||
"""Gets list of SignatureDefs in the model.
|
||||
|
||||
Example,
|
||||
```
|
||||
signatures = interpreter._get_full_signature_list()
|
||||
print(signatures)
|
||||
|
||||
# {
|
||||
# 'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}}
|
||||
# }
|
||||
|
||||
Then using the names in the signature list you can get a callable from
|
||||
get_signature_runner().
|
||||
```
|
||||
|
||||
Returns:
|
||||
A list of SignatureDef details in a dictionary structure.
|
||||
It is keyed on the SignatureDef method name, and the value holds
|
||||
dictionary of inputs and outputs.
|
||||
"""
|
||||
return self._interpreter.GetSignatureDefs()
|
||||
|
||||
def _set_input_tensor(self, input_name, value, method_name=None):
|
||||
"""Sets the value of the input tensor.
|
||||
|
||||
Input tensor is identified by `input_name` in the SignatureDef identified
|
||||
by `method_name`.
|
||||
If the model has a single SignatureDef then you can pass None as
|
||||
`method_name`.
|
||||
|
||||
Note this copies data in `value`.
|
||||
|
||||
Example,
|
||||
```
|
||||
input_data = np.array([1.2, 1.4], np.float32)
|
||||
signatures = interpreter.get_signature_list()
|
||||
print(signatures)
|
||||
# {
|
||||
# 'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}}
|
||||
# }
|
||||
interpreter._set_input_tensor(input_name='x', value=input_data,
|
||||
method_name='add_fn')
|
||||
```
|
||||
|
||||
Args:
|
||||
input_name: Name of the output tensor in the SignatureDef.
|
||||
value: Value of tensor to set as a numpy array.
|
||||
method_name: The exported method name for the SignatureDef, it can be None
|
||||
if and only if the model has a single SignatureDef. Default value is
|
||||
None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the interpreter could not set the tensor. Or
|
||||
if `method_name` is None and model doesn't have a single
|
||||
Signature.
|
||||
"""
|
||||
if method_name is None:
|
||||
if len(self._signature_defs) != 1:
|
||||
raise ValueError(
|
||||
'SignatureDef method_name is None and model has {0} Signatures. '
|
||||
'None is only allowed when the model has 1 SignatureDef'.format(
|
||||
len(self._signature_defs)))
|
||||
else:
|
||||
method_name = next(iter(self._signature_defs))
|
||||
self._interpreter.SetInputTensorFromSignatureDefName(
|
||||
input_name, method_name, value)
|
||||
|
||||
def get_signature_runner(self, method_name=None):
|
||||
"""Gets callable for inference of specific SignatureDef.
|
||||
|
||||
Example usage,
|
||||
```
|
||||
interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
fn = interpreter.get_signature_runner('div_with_remainder')
|
||||
output = fn(x=np.array([3]), y=np.array([2]))
|
||||
print(output)
|
||||
# {
|
||||
# 'quotient': array([1.], dtype=float32)
|
||||
# 'remainder': array([1.], dtype=float32)
|
||||
# }
|
||||
```
|
||||
|
||||
None can be passed for method_name if the model has a single Signature only.
|
||||
|
||||
All names used are this specific SignatureDef names.
|
||||
|
||||
|
||||
Args:
|
||||
method_name: The exported method name for the SignatureDef, it can be None
|
||||
if and only if the model has a single SignatureDef. Default value is
|
||||
None.
|
||||
|
||||
Returns:
|
||||
This returns a callable that can run inference for SignatureDef defined
|
||||
by argument 'method_name'.
|
||||
The callable will take key arguments corresponding to the arguments of the
|
||||
SignatureDef, that should have numpy values.
|
||||
The callable will returns dictionary that maps from output names to numpy
|
||||
values of the computed results.
|
||||
|
||||
Raises:
|
||||
ValueError: If passed method_name is invalid.
|
||||
"""
|
||||
if method_name is None:
|
||||
if len(self._signature_defs) != 1:
|
||||
raise ValueError(
|
||||
'SignatureDef method_name is None and model has {0} Signatures. '
|
||||
'None is only allowed when the model has 1 SignatureDef'.format(
|
||||
len(self._signature_defs)))
|
||||
else:
|
||||
method_name = next(iter(self._signature_defs))
|
||||
return SignatureRunner(interpreter=self, signature_def_name=method_name)
|
||||
|
||||
def get_tensor(self, tensor_index):
|
||||
"""Gets the value of the output tensor (get a copy).
|
||||
|
||||
|
@ -565,6 +565,48 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
|
||||
|
||||
} // namespace
|
||||
|
||||
PyObject* InterpreterWrapper::GetSignatureDefs() const {
|
||||
PyObject* result = PyDict_New();
|
||||
for (const auto& sig_def_name : interpreter_->signature_def_names()) {
|
||||
PyObject* signature_def = PyDict_New();
|
||||
PyObject* inputs = PyDict_New();
|
||||
PyObject* outputs = PyDict_New();
|
||||
const auto& signature_def_inputs =
|
||||
interpreter_->signature_inputs(sig_def_name->c_str());
|
||||
const auto& signature_def_outputs =
|
||||
interpreter_->signature_outputs(sig_def_name->c_str());
|
||||
for (const auto& input : signature_def_inputs) {
|
||||
PyDict_SetItemString(inputs, input.first.c_str(),
|
||||
PyLong_FromLong(input.second));
|
||||
}
|
||||
for (const auto& output : signature_def_outputs) {
|
||||
PyDict_SetItemString(outputs, output.first.c_str(),
|
||||
PyLong_FromLong(output.second));
|
||||
}
|
||||
|
||||
PyDict_SetItemString(signature_def, "inputs", inputs);
|
||||
PyDict_SetItemString(signature_def, "outputs", outputs);
|
||||
PyDict_SetItemString(result, sig_def_name->c_str(), signature_def);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::GetOutputTensorFromSignatureDefName(
|
||||
const char* output_name, const char* method_name) const {
|
||||
const auto& outputs = interpreter_->signature_outputs(method_name);
|
||||
const auto& output = outputs.find(output_name);
|
||||
if (output == outputs.end()) return nullptr;
|
||||
return GetTensor(output->second);
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::SetInputTensorFromSignatureDefName(
|
||||
const char* input_name, const char* method_name, PyObject* value) {
|
||||
const auto& inputs = interpreter_->signature_inputs(method_name);
|
||||
const auto& input = inputs.find(input_name);
|
||||
if (input == inputs.end()) return nullptr;
|
||||
return SetTensor(input->second, value);
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::GetTensor(int i) const {
|
||||
// Sanity check accessor
|
||||
TfLiteTensor* tensor = nullptr;
|
||||
|
@ -87,6 +87,12 @@ class InterpreterWrapper {
|
||||
PyObject* TensorQuantizationParameters(int i) const;
|
||||
PyObject* SetTensor(int i, PyObject* value);
|
||||
PyObject* GetTensor(int i) const;
|
||||
PyObject* SetInputTensorFromSignatureDefName(const char* input_name,
|
||||
const char* method_name,
|
||||
PyObject* value);
|
||||
PyObject* GetOutputTensorFromSignatureDefName(const char* output_name,
|
||||
const char* method_name) const;
|
||||
PyObject* GetSignatureDefs() const;
|
||||
PyObject* ResetVariableTensors();
|
||||
|
||||
int NumNodes() const;
|
||||
|
@ -141,6 +141,24 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
|
||||
[](const InterpreterWrapper& self, int i) {
|
||||
return tensorflow::PyoOrThrow(self.GetTensor(i));
|
||||
})
|
||||
.def("SetInputTensorFromSignatureDefName",
|
||||
[](InterpreterWrapper& self, const char* input_name,
|
||||
const char* method_name, py::handle& value) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
self.SetInputTensorFromSignatureDefName(
|
||||
input_name, method_name, value.ptr()));
|
||||
})
|
||||
.def("GetOutputTensorFromSignatureDefName",
|
||||
[](const InterpreterWrapper& self, const char* output_name,
|
||||
const char* method_name) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
self.GetOutputTensorFromSignatureDefName(output_name,
|
||||
method_name));
|
||||
})
|
||||
.def("GetSignatureDefs",
|
||||
[](InterpreterWrapper& self) {
|
||||
return tensorflow::PyoOrThrow(self.GetSignatureDefs());
|
||||
})
|
||||
.def("ResetVariableTensors",
|
||||
[](InterpreterWrapper& self) {
|
||||
return tensorflow::PyoOrThrow(self.ResetVariableTensors());
|
||||
|
@ -74,10 +74,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
|
||||
self.assertEqual(expected_value.numpy(), actual_value)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_INT8InputOutput', dtypes.int8),
|
||||
('_UINT8InputOutput', dtypes.uint8),
|
||||
('_INT16InputOutput', dtypes.int16))
|
||||
@parameterized.named_parameters(('_INT8InputOutput', dtypes.int8),
|
||||
('_UINT8InputOutput', dtypes.uint8),
|
||||
('_INT16InputOutput', dtypes.int16))
|
||||
@test_util.run_v2_only
|
||||
def testInvalidFloat(self, inference_input_output_type):
|
||||
root = self._getSimpleVariableModel()
|
||||
@ -194,10 +193,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
# Ensure that the quantized weights tflite model is smaller.
|
||||
self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_INT8InputOutput', dtypes.int8),
|
||||
('_UINT8InputOutput', dtypes.uint8),
|
||||
('_INT16InputOutput', dtypes.int16))
|
||||
@parameterized.named_parameters(('_INT8InputOutput', dtypes.int8),
|
||||
('_UINT8InputOutput', dtypes.uint8),
|
||||
('_INT16InputOutput', dtypes.int16))
|
||||
@test_util.run_v2_only
|
||||
def testInvalidPostTrainingDynamicRangeQuantization(
|
||||
self, inference_input_output_type):
|
||||
@ -227,11 +225,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
|
||||
('_IntOnly', True, False, dtypes.float32),
|
||||
('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
|
||||
('_IntOnly_UINT8InputOutput', True, False,
|
||||
dtypes.uint8),
|
||||
('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
|
||||
('_IntOnly_INT16Quantize', True, True, dtypes.float32),
|
||||
('_IntOnly_INT16Quantize_INT16InputOutput', True, True,
|
||||
dtypes.int16))
|
||||
('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
|
||||
def testIntegerQuantization(self, is_int_only, is_int16_quantize,
|
||||
inference_input_output_type):
|
||||
func, calibration_gen = self._getIntegerQuantizeModel()
|
||||
@ -302,7 +298,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
quantized_converter.inference_output_type = dtypes.int8
|
||||
quantized_converter.convert()
|
||||
self.assertEqual(
|
||||
"The inference_input_type and inference_output_type "
|
||||
'The inference_input_type and inference_output_type '
|
||||
"must be in ['tf.float32', 'tf.int16'].", str(error.exception))
|
||||
|
||||
def testCalibrateAndQuantizeBuiltinInt16(self):
|
||||
@ -380,8 +376,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_DefaultFLOAT32InputOutput', dtypes.float32),
|
||||
('_INT8InputOutput', dtypes.int8),
|
||||
('_UINT8InputOutput', dtypes.uint8))
|
||||
('_INT8InputOutput', dtypes.int8), ('_UINT8InputOutput', dtypes.uint8))
|
||||
@test_util.run_v2_only
|
||||
def testTrainingTimeQuantization(self, inference_input_output_type):
|
||||
model = self._getTrainingTimeQuantizedModel()
|
||||
@ -890,6 +885,85 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
|
||||
self.assertEqual(expected_value.numpy(), actual_value)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testSignatureDefs(self):
|
||||
"""Test converting SignatureDef is correct and uses SignatureDef API."""
|
||||
root = self._getMultiFunctionModel()
|
||||
input_data_0 = tf.constant(1., shape=[1])
|
||||
input_data_1 = tf.constant(3., shape=[1])
|
||||
mul_add_func = root.mul_add.get_concrete_function(input_data_1,
|
||||
input_data_0)
|
||||
|
||||
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
save(root, save_dir, {'mul_add': mul_add_func})
|
||||
|
||||
converter = lite.TFLiteConverterV2.from_saved_model(
|
||||
save_dir, signature_keys=['mul_add'])
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = root.mul_add(input_data_1, input_data_0)
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
signature_defs = interpreter.get_signature_list()
|
||||
results = self._evaluateTFLiteModelUsingSignatureDef(
|
||||
tflite_model, 'mul_add', {
|
||||
'y': input_data_0,
|
||||
'x': input_data_1
|
||||
})
|
||||
self.assertEqual(list(results.keys()), ['output_0'])
|
||||
self.assertEqual(expected_value.numpy(), results['output_0'])
|
||||
|
||||
# Verify the SignatureDef structure returned is as expected.
|
||||
self.assertEqual(len(signature_defs), 1)
|
||||
self.assertEqual(list(signature_defs.keys()), ['mul_add'])
|
||||
self.assertEqual(len(signature_defs.values()), 1)
|
||||
self.assertEqual(
|
||||
list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
|
||||
self.assertEqual(
|
||||
sorted(signature_defs['mul_add']['inputs']), ['x', 'y'])
|
||||
self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testSignatureDefsWithDefaultValue(self):
|
||||
"""Test converting SignatureDef is correct and uses SignatureDef API.
|
||||
|
||||
This test uses None as method_name to test default behavior.
|
||||
"""
|
||||
root = self._getMultiFunctionModel()
|
||||
input_data_0 = tf.constant(1., shape=[1])
|
||||
input_data_1 = tf.constant(3., shape=[1])
|
||||
mul_add_func = root.mul_add.get_concrete_function(input_data_1,
|
||||
input_data_0)
|
||||
|
||||
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
save(root, save_dir, {'mul_add': mul_add_func})
|
||||
|
||||
converter = lite.TFLiteConverterV2.from_saved_model(
|
||||
save_dir, signature_keys=['mul_add'])
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = root.mul_add(input_data_1, input_data_0)
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
signature_defs = interpreter.get_signature_list()
|
||||
results = self._evaluateTFLiteModelUsingSignatureDef(
|
||||
tflite_model, None, {
|
||||
'y': input_data_0,
|
||||
'x': input_data_1
|
||||
})
|
||||
self.assertEqual(list(results.keys()), ['output_0'])
|
||||
self.assertEqual(expected_value.numpy(), results['output_0'])
|
||||
|
||||
# Verify the SignatureDef structure returned is as expected.
|
||||
self.assertEqual(len(signature_defs), 1)
|
||||
self.assertEqual(list(signature_defs.keys()), ['mul_add'])
|
||||
self.assertEqual(len(signature_defs.values()), 1)
|
||||
self.assertEqual(
|
||||
list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
|
||||
self.assertEqual(
|
||||
sorted(signature_defs['mul_add']['inputs']), ['x', 'y'])
|
||||
self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testMultipleFunctionModel(self):
|
||||
"""Convert multiple functions in a multi-functional model."""
|
||||
@ -1542,54 +1616,5 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
str(error.exception))
|
||||
|
||||
|
||||
class AffineOpThenMulFusionTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
@parameterized.named_parameters(('should_fuse_1d', [2], True),
|
||||
('should_fuse_1x2', [1, 2], True),
|
||||
('should_not_fuse_2x1', [2, 1], False),
|
||||
('should_not_fuse_2x2', [2, 2], False))
|
||||
@test_util.run_v2_only
|
||||
def testFullyConnectedFusion(self, multiplier_shape, can_fuse):
|
||||
"""Test fusion of (x ∗ w) * m into fullyconnected."""
|
||||
|
||||
@tf.function
|
||||
def func(x):
|
||||
w = tf.constant([3., 4., 5., 6.], shape=[2, 2])
|
||||
m_value = [7., 8.] if sum(multiplier_shape) < 4 else [7., 8., 9., 10.]
|
||||
m = tf.constant(m_value, shape=multiplier_shape)
|
||||
return tf.matmul(x, w) * m
|
||||
|
||||
input_data = tf.constant([1., 2.], shape=[1, 2])
|
||||
self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
|
||||
|
||||
@parameterized.named_parameters(('should_fuse_1d', [2], True),
|
||||
('should_fuse_1x2', [1, 2], True),
|
||||
('should_not_fuse_2x1', [2, 1], False))
|
||||
@test_util.run_v2_only
|
||||
def testConvFusion(self, multiplier_shape, can_fuse):
|
||||
"""Test fusion of (x ∗ w) * m into conv2d."""
|
||||
|
||||
@tf.function
|
||||
def func(x):
|
||||
w = tf.constant([3., 4., 5., 6.], shape=[2, 1, 1, 2])
|
||||
m = tf.constant([7., 8.], shape=multiplier_shape)
|
||||
return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') * m
|
||||
|
||||
input_data = tf.constant([1., 2.], shape=[1, 1, 2, 1])
|
||||
self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
|
||||
|
||||
def _checkAffineFusion(self, func, input_data, expected_number_of_ops):
|
||||
concrete_func = func.get_concrete_function(input_data)
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
tflite_model = converter.convert()
|
||||
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
assert len(interpreter._get_ops_details()) == expected_number_of_ops
|
||||
|
||||
expected_value = func(input_data)
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
|
||||
self.assertAllClose(expected_value.numpy(), actual_value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -67,6 +67,24 @@ class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
interpreter.get_tensor(details['index']) for details in output_details
|
||||
]
|
||||
|
||||
def _evaluateTFLiteModelUsingSignatureDef(self, tflite_model, method_name,
|
||||
inputs):
|
||||
"""Evaluates the model on the `inputs`.
|
||||
|
||||
Args:
|
||||
tflite_model: TensorFlow Lite model.
|
||||
method_name: Exported Method name of the SavedModel.
|
||||
inputs: Map from input tensor names in the SignatureDef to tensor value.
|
||||
|
||||
Returns:
|
||||
Dictionary of outputs.
|
||||
Key is the output name in the SignatureDef 'method_name'
|
||||
Value is the output value
|
||||
"""
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
signature_runner = interpreter.get_signature_runner(method_name)
|
||||
return signature_runner(**inputs)
|
||||
|
||||
def _getSimpleVariableModel(self):
|
||||
root = tracking.AutoTrackable()
|
||||
root.v1 = variables.Variable(3.)
|
||||
@ -95,6 +113,12 @@ class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.z = variables.Variable(3.)
|
||||
return x - self.z
|
||||
|
||||
@def_function.function
|
||||
def mul_add(self, x, y):
|
||||
if self.z is None:
|
||||
self.z = variables.Variable(3.)
|
||||
return x * self.z + y
|
||||
|
||||
return BasicModel()
|
||||
|
||||
def _assertValidDebugInfo(self, debug_info):
|
||||
|
@ -18,6 +18,14 @@ tf_class {
|
||||
name: "get_output_details"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_signature_list"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_signature_runner"
|
||||
argspec: "args=[\'self\', \'method_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_tensor"
|
||||
argspec: "args=[\'self\', \'tensor_index\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -18,6 +18,14 @@ tf_class {
|
||||
name: "get_output_details"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_signature_list"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_signature_runner"
|
||||
argspec: "args=[\'self\', \'method_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_tensor"
|
||||
argspec: "args=[\'self\', \'tensor_index\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user