Remove constant folding grappler optimization from MLIR conversion pipeline
PiperOrigin-RevId: 305113276 Change-Id: I9788e15c30929df184e54c2aa557d10d78ebc852
This commit is contained in:
parent
488a1a3c8c
commit
2fb65633e9
tensorflow/lite/python
@ -302,7 +302,9 @@ class TFLiteConverterBase(object):
|
||||
"""
|
||||
if not optimizers:
|
||||
optimizers = []
|
||||
optimizers.append("constfold")
|
||||
# MLIR converter will take care of constant folding instead of grappler.
|
||||
if not self.experimental_new_converter:
|
||||
optimizers.append("constfold")
|
||||
|
||||
is_only_flex_enabled = (
|
||||
set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
|
||||
@ -595,12 +597,17 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
output_tensors = frozen_func.outputs
|
||||
|
||||
# Run a Grappler pass.
|
||||
graph_def = _run_graph_optimizations(
|
||||
graph_def,
|
||||
input_tensors,
|
||||
output_tensors,
|
||||
config=self._grappler_config(),
|
||||
graph=frozen_func.graph)
|
||||
grappler_config = self._grappler_config()
|
||||
# Skip running grappler when there are no optimizers to run. If not,
|
||||
# grappler will run with the default optimizer set and it will lead to
|
||||
# causing an unexpected behavior.
|
||||
if grappler_config.graph_options.rewrite_options.optimizers:
|
||||
graph_def = _run_graph_optimizations(
|
||||
graph_def,
|
||||
input_tensors,
|
||||
output_tensors,
|
||||
config=grappler_config,
|
||||
graph=frozen_func.graph)
|
||||
|
||||
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
|
||||
self.representative_dataset, graph_def)
|
||||
|
@ -2236,6 +2236,37 @@ class GrapplerTest(TestModels, parameterized.TestCase):
|
||||
self.assertEqual('Placeholder', input_details[0]['name'])
|
||||
self.assertEqual('Const', input_details[1]['name'])
|
||||
|
||||
def testGrapplerConstFolding(self):
|
||||
# Constant folding converts the following add operation to tf.broadcast_to
|
||||
# operation which was not supported by the TFLite at the time this test was
|
||||
# added.
|
||||
@def_function.function
|
||||
def plus_placeholder(x, placeholder):
|
||||
return x + placeholder
|
||||
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
|
||||
out_tensor = plus_placeholder(
|
||||
array_ops.zeros([2, 2, 2]),
|
||||
array_ops.reshape(in_tensor, shape=[2, 2]))
|
||||
sess = session.Session()
|
||||
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
# Only disable this path in MLIR conversion for toco compatibility.
|
||||
converter.experimental_new_converter = True
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertLen(input_details, 1)
|
||||
self.assertEqual('Placeholder', input_details[0]['name'])
|
||||
|
||||
|
||||
class ImportOpsUtilTest(LiteTest):
|
||||
|
||||
def testGetPotentiallySupportedOps(self):
|
||||
|
Loading…
Reference in New Issue
Block a user