Make TFLiteConverter build with MLIR internally by default.
PiperOrigin-RevId: 259064943
This commit is contained in:
parent
6e92ec7e92
commit
4b0a805992
@ -40,28 +40,6 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
|
||||
|
||||
def mlir_convert_and_check_for_unsupported(test_object, converter):
|
||||
"""Run the converter but don't fail MLIR was not built.
|
||||
|
||||
Args:
|
||||
test_object: PyTest object.
|
||||
converter: A TFLiteConverter
|
||||
|
||||
Returns:
|
||||
The converted TF lite model or None if mlir support is not builtinto the
|
||||
binary.
|
||||
"""
|
||||
try:
|
||||
model = converter.convert()
|
||||
test_object.assertTrue(model)
|
||||
return model
|
||||
except lite.ConverterError as e:
|
||||
if not e.message.startswith('This flag is not supported by this version'):
|
||||
raise e
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@test_util.run_v1_only('Incompatible with 2.0.')
|
||||
class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@ -75,9 +53,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
converter.experimental_enable_mlir_converter = True
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
@ -105,9 +81,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
# Convert model and ensure model is not None.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
@ -144,9 +118,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
'inputA': (0., 1.),
|
||||
'inputB': (0., 1.)
|
||||
} # mean, std_dev
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
@ -182,9 +154,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
# Test conversion with the scalar input shape.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
@ -228,18 +198,13 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
# Convert float model.
|
||||
float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
|
||||
[out_tensor])
|
||||
float_tflite = mlir_convert_and_check_for_unsupported(self, float_converter)
|
||||
if float_tflite is None:
|
||||
return
|
||||
float_tflite = float_converter.convert()
|
||||
|
||||
# Convert quantized weights model.
|
||||
quantized_converter = lite.TFLiteConverter.from_session(
|
||||
sess, [in_tensor_1], [out_tensor])
|
||||
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
|
||||
quantized_tflite = mlir_convert_and_check_for_unsupported(
|
||||
self, quantized_converter)
|
||||
if quantized_tflite is None:
|
||||
return
|
||||
quantized_tflite = quantized_converter.convert()
|
||||
|
||||
# Ensure that the quantized weights tflite model is smaller.
|
||||
self.assertLess(len(quantized_tflite), len(float_tflite))
|
||||
@ -266,9 +231,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
# Convert model and ensure model is not None.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [placeholder],
|
||||
[output_node])
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
@ -322,9 +285,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase):
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
converter.experimental_enable_mlir_converter = True
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = root.f(input_data)
|
||||
@ -359,9 +320,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase):
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
converter.experimental_enable_mlir_converter = True
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = concrete_func(**input_data)
|
||||
@ -389,9 +348,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase):
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
converter.experimental_enable_mlir_converter = True
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = concrete_func(input_data)[0]
|
||||
@ -422,9 +379,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase):
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
converter.experimental_enable_mlir_converter = True
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = concrete_func(input_data)
|
||||
@ -449,9 +404,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase):
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
converter.experimental_enable_mlir_converter = True
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = concrete_func(input_data)
|
||||
@ -478,9 +431,7 @@ class TestFlexMode(test_util.TensorFlowTestCase):
|
||||
[out_tensor])
|
||||
converter.experimental_enable_mlir_converter = True
|
||||
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Ensures the model contains TensorFlow ops.
|
||||
# TODO(nupurgarg): Check values once there is a Python delegate interface.
|
||||
@ -505,10 +456,7 @@ class TestFlexMode(test_util.TensorFlowTestCase):
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
converter.experimental_enable_mlir_converter = True
|
||||
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
|
||||
|
||||
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
|
||||
if tflite_model is None:
|
||||
return
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Ensures the model contains TensorFlow ops.
|
||||
# TODO(nupurgarg): Check values once there is a Python delegate interface.
|
||||
|
@ -1,5 +1,5 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
load("//tensorflow:tensorflow.bzl", "if_mlir", "py_binary", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "if_mlir_tflite", "py_binary", "tf_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -22,7 +22,7 @@ cc_library(
|
||||
name = "toco_python_api",
|
||||
srcs = ["toco_python_api.cc"],
|
||||
hdrs = ["toco_python_api.h"],
|
||||
defines = if_mlir(
|
||||
defines = if_mlir_tflite(
|
||||
if_false = [],
|
||||
if_true = ["TFLITE_BUILD_WITH_MLIR_CONVERTER"],
|
||||
),
|
||||
@ -46,7 +46,7 @@ cc_library(
|
||||
"//tensorflow/core:ops",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}) + if_mlir(
|
||||
}) + if_mlir_tflite(
|
||||
if_false = [],
|
||||
if_true = ["//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer"],
|
||||
),
|
||||
|
@ -2493,5 +2493,8 @@ def if_mlir(if_true, if_false = []):
|
||||
"//tensorflow:with_mlir_support": if_true,
|
||||
})
|
||||
|
||||
def if_mlir_tflite(if_true, if_false = []):
|
||||
return if_mlir(if_true, if_false)
|
||||
|
||||
def tfcompile_extra_flags():
|
||||
return ""
|
||||
|
Loading…
Reference in New Issue
Block a user