Add constant folding Grappler pass to TFLiteConverter.

PiperOrigin-RevId: 246536988
This commit is contained in:
Nupur Garg 2019-05-03 10:54:30 -07:00 committed by TensorFlower Gardener
parent 56764feeb6
commit 2d68c0629c
3 changed files with 80 additions and 16 deletions

View File

@ -154,12 +154,13 @@ class TFLiteConverterBase(object):
def _grappler_config(self):
is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == set(self._target_ops)
optimizers = ["constfold"]
if is_only_flex_enabled:
# The layout optimizer turns NHCW to NCHW. This provides performance
# optimizations when Flex mode is enabled. However, this is not compatible
# with builtin ops.
return _get_grappler_config(["layout"])
return None
optimizers.append("layout")
return _get_grappler_config(optimizers)
def _validate_representative_dataset(self):
if self.representative_dataset:
@ -350,14 +351,12 @@ class TFLiteConverterV2(TFLiteConverterBase):
# Run a Grappler pass.
graph_def = frozen_func.graph.as_graph_def()
config = self._grappler_config()
if config:
graph_def = _run_graph_optimizations(
graph_def,
input_tensors,
output_tensors,
config,
graph=frozen_func.graph)
graph_def = _run_graph_optimizations(
graph_def,
input_tensors,
output_tensors,
config=self._grappler_config(),
graph=frozen_func.graph)
# Checks dimensions in input tensor.
for tensor in input_tensors:
@ -879,12 +878,11 @@ class TFLiteConverter(TFLiteConverterBase):
optimized_graph = self._graph_def
if self.inference_type != constants.QUANTIZED_UINT8:
try:
config = self._grappler_config()
if config:
optimized_graph = _run_graph_optimizations(self._graph_def,
self._input_tensors,
self._output_tensors,
config)
optimized_graph = _run_graph_optimizations(
self._graph_def,
self._input_tensors,
self._output_tensors,
config=self._grappler_config())
except Exception:
optimized_graph = self._graph_def

View File

@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
@ -1589,6 +1590,40 @@ class FromKerasFile(test_util.TensorFlowTestCase):
self.assertTrue(tflite_model)
@test_util.run_v1_only('Incompatible with 2.0.')
class GrapplerTest(test_util.TensorFlowTestCase):
def testConstantFolding(self):
# Constant folding handles the tf.broadcast_to operation which was not
# supported by the TFLite at the time this test was added.
in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32)
y_const = constant_op.constant([1., 2., 3.])
y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output')
sess = session.Session()
# Convert model.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
tflite_model = converter.convert()
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([3, 3] == input_details[0]['shape']).all())
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([3, 3] == output_details[0]['shape']).all())
class ImportOpsUtilTest(test_util.TensorFlowTestCase):
def testGetPotentiallySupportedOps(self):

View File

@ -30,6 +30,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -396,5 +398,34 @@ class FromKerasModelTest(TestModels):
np.testing.assert_almost_equal(tf_result[0], tflite_result, 5)
class GrapplerTest(TestModels):
@test_util.run_v2_only
def testConstantFolding(self):
# Constant folding handles the tf.broadcast_to operation which was not
# supported by the TFLite at the time this test was added.
input_data = constant_op.constant([1., 2., 3., 4., 5., 6., 7., 8., 9.],
shape=[3, 3])
@def_function.function
def func(x):
y_const = constant_op.constant([1., 2., 3.])
y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
return math_ops.matmul(x, y_broadcast)
root = tracking.AutoTrackable()
root.f = func
concrete_func = root.f.get_concrete_function(input_data)
# Convert model.
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
# Check values from converted model.
expected_value = root.f(input_data)
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
np.testing.assert_array_equal(expected_value.numpy(), actual_value)
if __name__ == '__main__':
test.main()