Disable grappler and freezing steps when MLIR SavedModel conversion path is on
PiperOrigin-RevId: 305814559 Change-Id: I04528dcfdab7560531bc7c594ee22b0e5061bb59
This commit is contained in:
		
							parent
							
								
									f3c3387186
								
							
						
					
					
						commit
						f2100b9b51
					
				@ -72,6 +72,7 @@ from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundEr
 | 
			
		||||
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
 | 
			
		||||
from tensorflow.python.keras.saving import saving_utils as _saving_utils
 | 
			
		||||
from tensorflow.python.lib.io import file_io as _file_io
 | 
			
		||||
from tensorflow.python.saved_model import loader_impl as _loader_impl
 | 
			
		||||
from tensorflow.python.saved_model import signature_constants as _signature_constants
 | 
			
		||||
from tensorflow.python.saved_model import tag_constants as _tag_constants
 | 
			
		||||
from tensorflow.python.saved_model.load import load as _load
 | 
			
		||||
@ -597,10 +598,29 @@ class TFLiteConverterV2(TFLiteConverterBase):
 | 
			
		||||
    self._parse_saved_model_args()
 | 
			
		||||
 | 
			
		||||
    # graph_def is used here to preserve the node bug information
 | 
			
		||||
    if self._saved_model_dir:
 | 
			
		||||
      graph = _ops.Graph()
 | 
			
		||||
      saved_model = _loader_impl.SavedModelLoader(self._saved_model_dir)
 | 
			
		||||
      saved_model.load_graph(graph, tags=self._saved_model_tags)
 | 
			
		||||
      meta_graph = saved_model.get_meta_graph_def_from_tags(
 | 
			
		||||
          self._saved_model_tags)
 | 
			
		||||
      signature_def = meta_graph.signature_def[
 | 
			
		||||
          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
 | 
			
		||||
      input_tensors = [
 | 
			
		||||
          graph.get_tensor_by_name(signature_def.inputs[key].name)
 | 
			
		||||
          for key in signature_def.inputs
 | 
			
		||||
      ]
 | 
			
		||||
      output_tensors = [
 | 
			
		||||
          graph.get_tensor_by_name(signature_def.outputs[key].name)
 | 
			
		||||
          for key in signature_def.outputs
 | 
			
		||||
      ]
 | 
			
		||||
      self._graph_def = graph_def = meta_graph.graph_def
 | 
			
		||||
    else:
 | 
			
		||||
      frozen_func, graph_def = (
 | 
			
		||||
          _convert_to_constants.convert_variables_to_constants_v2_as_graph(
 | 
			
		||||
              self._funcs[0], lower_control_flow=False))
 | 
			
		||||
      self._graph_def = graph_def
 | 
			
		||||
 | 
			
		||||
      input_tensors = [
 | 
			
		||||
          tensor for tensor in frozen_func.inputs
 | 
			
		||||
          if tensor.dtype != _dtypes.resource
 | 
			
		||||
@ -1231,6 +1251,7 @@ class TFLiteConverter(TFLiteConverterBase):
 | 
			
		||||
          "are not enabled.")
 | 
			
		||||
 | 
			
		||||
    optimized_graph = self._graph_def
 | 
			
		||||
    if not self._saved_model_dir:
 | 
			
		||||
      # if it is not uint8 or int8 with post-training quantization, it is not
 | 
			
		||||
      # quantization aware training, then graph optimization is applied.
 | 
			
		||||
      # Graph optimization is disabled for quantization aware training.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user