Add constant folding Grappler pass to TFLiteConverter.
PiperOrigin-RevId: 246536988
This commit is contained in:
parent
56764feeb6
commit
2d68c0629c
@ -154,12 +154,13 @@ class TFLiteConverterBase(object):
|
||||
|
||||
def _grappler_config(self):
|
||||
is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == set(self._target_ops)
|
||||
optimizers = ["constfold"]
|
||||
if is_only_flex_enabled:
|
||||
# The layout optimizer turns NHCW to NCHW. This provides performance
|
||||
# optimizations when Flex mode is enabled. However, this is not compatible
|
||||
# with builtin ops.
|
||||
return _get_grappler_config(["layout"])
|
||||
return None
|
||||
optimizers.append("layout")
|
||||
return _get_grappler_config(optimizers)
|
||||
|
||||
def _validate_representative_dataset(self):
|
||||
if self.representative_dataset:
|
||||
@ -350,14 +351,12 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
|
||||
# Run a Grappler pass.
|
||||
graph_def = frozen_func.graph.as_graph_def()
|
||||
config = self._grappler_config()
|
||||
if config:
|
||||
graph_def = _run_graph_optimizations(
|
||||
graph_def,
|
||||
input_tensors,
|
||||
output_tensors,
|
||||
config,
|
||||
graph=frozen_func.graph)
|
||||
graph_def = _run_graph_optimizations(
|
||||
graph_def,
|
||||
input_tensors,
|
||||
output_tensors,
|
||||
config=self._grappler_config(),
|
||||
graph=frozen_func.graph)
|
||||
|
||||
# Checks dimensions in input tensor.
|
||||
for tensor in input_tensors:
|
||||
@ -879,12 +878,11 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
optimized_graph = self._graph_def
|
||||
if self.inference_type != constants.QUANTIZED_UINT8:
|
||||
try:
|
||||
config = self._grappler_config()
|
||||
if config:
|
||||
optimized_graph = _run_graph_optimizations(self._graph_def,
|
||||
self._input_tensors,
|
||||
self._output_tensors,
|
||||
config)
|
||||
optimized_graph = _run_graph_optimizations(
|
||||
self._graph_def,
|
||||
self._input_tensors,
|
||||
self._output_tensors,
|
||||
config=self._grappler_config())
|
||||
except Exception:
|
||||
optimized_graph = self._graph_def
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
@ -1589,6 +1590,40 @@ class FromKerasFile(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
|
||||
@test_util.run_v1_only('Incompatible with 2.0.')
|
||||
class GrapplerTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testConstantFolding(self):
|
||||
# Constant folding handles the tf.broadcast_to operation which was not
|
||||
# supported by the TFLite at the time this test was added.
|
||||
in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32)
|
||||
y_const = constant_op.constant([1., 2., 3.])
|
||||
y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
|
||||
out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output')
|
||||
sess = session.Session()
|
||||
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertEqual(1, len(input_details))
|
||||
self.assertEqual('Placeholder', input_details[0]['name'])
|
||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||
self.assertTrue(([3, 3] == input_details[0]['shape']).all())
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(1, len(output_details))
|
||||
self.assertEqual('output', output_details[0]['name'])
|
||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||
self.assertTrue(([3, 3] == output_details[0]['shape']).all())
|
||||
|
||||
|
||||
class ImportOpsUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testGetPotentiallySupportedOps(self):
|
||||
|
||||
@ -30,6 +30,8 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -396,5 +398,34 @@ class FromKerasModelTest(TestModels):
|
||||
np.testing.assert_almost_equal(tf_result[0], tflite_result, 5)
|
||||
|
||||
|
||||
class GrapplerTest(TestModels):
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testConstantFolding(self):
|
||||
# Constant folding handles the tf.broadcast_to operation which was not
|
||||
# supported by the TFLite at the time this test was added.
|
||||
input_data = constant_op.constant([1., 2., 3., 4., 5., 6., 7., 8., 9.],
|
||||
shape=[3, 3])
|
||||
|
||||
@def_function.function
|
||||
def func(x):
|
||||
y_const = constant_op.constant([1., 2., 3.])
|
||||
y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
|
||||
return math_ops.matmul(x, y_broadcast)
|
||||
|
||||
root = tracking.AutoTrackable()
|
||||
root.f = func
|
||||
concrete_func = root.f.get_concrete_function(input_data)
|
||||
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = root.f(input_data)
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
|
||||
np.testing.assert_array_equal(expected_value.numpy(), actual_value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user