Disable grappler and freezing steps when MLIR SavedModel conversion path is on
PiperOrigin-RevId: 305814559 Change-Id: I04528dcfdab7560531bc7c594ee22b0e5061bb59
This commit is contained in:
parent
f3c3387186
commit
f2100b9b51
@ -72,6 +72,7 @@ from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundEr
|
|||||||
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
|
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
|
||||||
from tensorflow.python.keras.saving import saving_utils as _saving_utils
|
from tensorflow.python.keras.saving import saving_utils as _saving_utils
|
||||||
from tensorflow.python.lib.io import file_io as _file_io
|
from tensorflow.python.lib.io import file_io as _file_io
|
||||||
|
from tensorflow.python.saved_model import loader_impl as _loader_impl
|
||||||
from tensorflow.python.saved_model import signature_constants as _signature_constants
|
from tensorflow.python.saved_model import signature_constants as _signature_constants
|
||||||
from tensorflow.python.saved_model import tag_constants as _tag_constants
|
from tensorflow.python.saved_model import tag_constants as _tag_constants
|
||||||
from tensorflow.python.saved_model.load import load as _load
|
from tensorflow.python.saved_model.load import load as _load
|
||||||
@ -597,28 +598,47 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
self._parse_saved_model_args()
|
self._parse_saved_model_args()
|
||||||
|
|
||||||
# graph_def is used here to preserve the node bug information
|
# graph_def is used here to preserve the node bug information
|
||||||
frozen_func, graph_def = (
|
if self._saved_model_dir:
|
||||||
_convert_to_constants.convert_variables_to_constants_v2_as_graph(
|
graph = _ops.Graph()
|
||||||
self._funcs[0], lower_control_flow=False))
|
saved_model = _loader_impl.SavedModelLoader(self._saved_model_dir)
|
||||||
self._graph_def = graph_def
|
saved_model.load_graph(graph, tags=self._saved_model_tags)
|
||||||
input_tensors = [
|
meta_graph = saved_model.get_meta_graph_def_from_tags(
|
||||||
tensor for tensor in frozen_func.inputs
|
self._saved_model_tags)
|
||||||
if tensor.dtype != _dtypes.resource
|
signature_def = meta_graph.signature_def[
|
||||||
]
|
_signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||||
output_tensors = frozen_func.outputs
|
input_tensors = [
|
||||||
|
graph.get_tensor_by_name(signature_def.inputs[key].name)
|
||||||
|
for key in signature_def.inputs
|
||||||
|
]
|
||||||
|
output_tensors = [
|
||||||
|
graph.get_tensor_by_name(signature_def.outputs[key].name)
|
||||||
|
for key in signature_def.outputs
|
||||||
|
]
|
||||||
|
self._graph_def = graph_def = meta_graph.graph_def
|
||||||
|
else:
|
||||||
|
frozen_func, graph_def = (
|
||||||
|
_convert_to_constants.convert_variables_to_constants_v2_as_graph(
|
||||||
|
self._funcs[0], lower_control_flow=False))
|
||||||
|
self._graph_def = graph_def
|
||||||
|
|
||||||
# Run a Grappler pass.
|
input_tensors = [
|
||||||
grappler_config = self._grappler_config()
|
tensor for tensor in frozen_func.inputs
|
||||||
# Skip running grappler when there are no optimizers to run. If not,
|
if tensor.dtype != _dtypes.resource
|
||||||
# grappler will run with the default optimizer set and it will lead to
|
]
|
||||||
# causing an unexpected behavior.
|
output_tensors = frozen_func.outputs
|
||||||
if grappler_config.graph_options.rewrite_options.optimizers:
|
|
||||||
graph_def = _run_graph_optimizations(
|
# Run a Grappler pass.
|
||||||
graph_def,
|
grappler_config = self._grappler_config()
|
||||||
input_tensors,
|
# Skip running grappler when there are no optimizers to run. If not,
|
||||||
output_tensors,
|
# grappler will run with the default optimizer set and it will lead to
|
||||||
config=grappler_config,
|
# causing an unexpected behavior.
|
||||||
graph=frozen_func.graph)
|
if grappler_config.graph_options.rewrite_options.optimizers:
|
||||||
|
graph_def = _run_graph_optimizations(
|
||||||
|
graph_def,
|
||||||
|
input_tensors,
|
||||||
|
output_tensors,
|
||||||
|
config=grappler_config,
|
||||||
|
graph=frozen_func.graph)
|
||||||
|
|
||||||
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
|
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
|
||||||
self.representative_dataset, graph_def)
|
self.representative_dataset, graph_def)
|
||||||
@ -1231,28 +1251,29 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
"are not enabled.")
|
"are not enabled.")
|
||||||
|
|
||||||
optimized_graph = self._graph_def
|
optimized_graph = self._graph_def
|
||||||
# if it is not uint8 or int8 with post-training quantization, it is not
|
if not self._saved_model_dir:
|
||||||
# quantization aware training, then graph optimization is applied.
|
# if it is not uint8 or int8 with post-training quantization, it is not
|
||||||
# Graph optimization is disabled for quantization aware training.
|
# quantization aware training, then graph optimization is applied.
|
||||||
if (self.inference_type != constants.QUANTIZED_UINT8 or
|
# Graph optimization is disabled for quantization aware training.
|
||||||
(self.inference_type == constants.INT8 and
|
if (self.inference_type != constants.QUANTIZED_UINT8 or
|
||||||
(post_training_optimize or weight_only_quantize))):
|
(self.inference_type == constants.INT8 and
|
||||||
try:
|
(post_training_optimize or weight_only_quantize))):
|
||||||
# TODO(b/150163103): Merge `disabling lower using switch merge' calls.
|
try:
|
||||||
# Grappler will also try to lower while loop into switch merge
|
# TODO(b/150163103): Merge `disabling lower using switch merge' calls.
|
||||||
# representation which is undesired for Ophints, so we simply remove
|
# Grappler will also try to lower while loop into switch merge
|
||||||
# those attributes to prevent Grappler from doing so.
|
# representation which is undesired for Ophints, so we simply remove
|
||||||
graph_def = _convert_to_constants.disable_lower_using_switch_merge(
|
# those attributes to prevent Grappler from doing so.
|
||||||
optimized_graph)
|
graph_def = _convert_to_constants.disable_lower_using_switch_merge(
|
||||||
# Run function inlining optimization to ensure any models generated
|
optimized_graph)
|
||||||
# through the from_frozen_graph path have been inlined.
|
# Run function inlining optimization to ensure any models generated
|
||||||
optimized_graph = _run_graph_optimizations(
|
# through the from_frozen_graph path have been inlined.
|
||||||
graph_def,
|
optimized_graph = _run_graph_optimizations(
|
||||||
self._input_tensors,
|
graph_def,
|
||||||
self._output_tensors,
|
self._input_tensors,
|
||||||
config=self._grappler_config(["function"]))
|
self._output_tensors,
|
||||||
except Exception:
|
config=self._grappler_config(["function"]))
|
||||||
optimized_graph = self._graph_def
|
except Exception:
|
||||||
|
optimized_graph = self._graph_def
|
||||||
|
|
||||||
self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
|
self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user