diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 4db28058cd0..c338b723a4a 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -16,9 +16,12 @@ limitations under the License. #include <string> +#include "llvm/ADT/StringSet.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project #include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project @@ -41,6 +44,77 @@ limitations under the License. namespace tensorflow { +Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags, + mlir::OwningModuleRef* module) { + mlir::FuncOp entry_function = nullptr; + for (auto func : module->get().getOps<mlir::FuncOp>()) { + if (auto tf_attrs = + func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) { + // TODO(jaesung): There could be multiple entry functions. Let's handle + // such cases if there are any needs for that. + if (entry_function != nullptr) { + return errors::InvalidArgument( + "There should be only one tf.entry_function"); + } + entry_function = func; + } + } + if (entry_function == nullptr) { + return errors::InvalidArgument("no tf.entry_function found"); + } + + // Get the list of input Op names from the function attribute. 
+ mlir::DictionaryAttr tf_attrs = + entry_function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function"); + llvm::SmallVector<llvm::StringRef, 4> function_input_names; + function_input_names.reserve(model_flags.input_arrays().size()); + auto input_attr = tf_attrs.get("inputs"); + if (!input_attr) { + return errors::InvalidArgument("no inputs attribute found"); + } + auto input_names = input_attr.cast<mlir::StringAttr>().getValue(); + input_names.split(function_input_names, ","); + if (function_input_names.size() != model_flags.input_arrays().size()) { + return errors::InvalidArgument( + "input array size mismatch: got ", function_input_names.size(), + ", expected: ", model_flags.input_arrays().size()); + } + llvm::StringSet<> function_input_names_set; + function_input_names_set.insert(function_input_names.begin(), + function_input_names.end()); + for (const auto& input_array : model_flags.input_arrays()) { + if (function_input_names_set.count(input_array.name()) == 0) { + return errors::InvalidArgument("input array name (", input_array.name(), + ") does not exist in the given graph"); + } + } + + // Get the list of output Op names from the function attribute. 
+ llvm::SmallVector<llvm::StringRef, 4> function_output_names; + function_output_names.reserve(model_flags.output_arrays().size()); + auto output_attr = tf_attrs.get("outputs"); + if (!output_attr) { + return errors::InvalidArgument("no outputs attribute found"); + } + auto output_names = output_attr.cast<mlir::StringAttr>().getValue(); + output_names.split(function_output_names, ","); + if (function_output_names.size() != model_flags.output_arrays().size()) { + return errors::InvalidArgument( + "output array size mismatch: got ", function_output_names.size(), + ", expected: ", model_flags.output_arrays().size()); + } + llvm::StringSet<> function_output_names_set; + function_output_names_set.insert(function_output_names.begin(), + function_output_names.end()); + for (const auto& output_array : model_flags.output_arrays()) { + if (function_output_names_set.count(output_array) == 0) { + return errors::InvalidArgument("output array name (", output_array, + ") does not exist in the given graph"); + } + } + return Status::OK(); +} + Status ConvertSavedModelToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, string* result) { @@ -77,6 +151,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer( model_flags.saved_model_version(), tags, exported_names, &context)); + if (!model_flags.input_arrays().empty() || + !model_flags.output_arrays().empty()) { + TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module)); + } + mlir::TFL::PassConfig pass_config(quant_specs); bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 0b140ec3826..bf9bee02971 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -400,7 +400,10 @@ def build_toco_convert_protos(input_tensors, model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in 
enumerate(input_tensors): input_array = model.input_arrays.add() - input_array.name = util.get_tensor_name(input_tensor) + if saved_model_dir: + input_array.name = input_tensor.name + else: + input_array.name = util.get_tensor_name(input_tensor) input_array.data_type = util.convert_dtype_to_tflite_type( input_tensor.dtype) @@ -423,7 +426,10 @@ def build_toco_convert_protos(input_tensors, input_array.shape.dims.extend(dims) for output_tensor in output_tensors: - model.output_arrays.append(util.get_tensor_name(output_tensor)) + if saved_model_dir: + model.output_arrays.append(output_tensor.name) + else: + model.output_arrays.append(util.get_tensor_name(output_tensor)) model.allow_nonexistent_arrays = allow_nonexistent_arrays diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 83b0e2b734c..fef8c9ce3cf 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -382,6 +382,9 @@ class TFLiteConverterBase(object): def _parse_saved_model_args(self): """Parses SavedModel arguments from the given Keras/RNN SavedModel.""" + if not self.experimental_new_converter: + self._saved_model_dir = None + return if self._saved_model_dir: try: saved_model_proto, _ = ( diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 9d9f17547b0..445a8b4cfed 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -1634,19 +1634,19 @@ class FromSavedModelTest(TestModels): input_details = interpreter.get_input_details() self.assertEqual(2, len(input_details)) - self.assertEqual('inputA', input_details[0]['name']) + self.assertStartsWith(input_details[0]['name'], 'inputA') self.assertEqual(np.float32, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((0., 0.), input_details[0]['quantization']) - self.assertEqual('inputB', input_details[1]['name']) + self.assertStartsWith(input_details[1]['name'], 'inputB') 
self.assertEqual(np.float32, input_details[1]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) self.assertEqual((0., 0.), input_details[1]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) - self.assertEqual('add', output_details[0]['name']) + self.assertStartsWith(output_details[0]['name'], 'add') self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) @@ -1696,19 +1696,19 @@ class FromSavedModelTest(TestModels): input_details = interpreter.get_input_details() self.assertEqual(2, len(input_details)) - self.assertEqual('inputA', input_details[0]['name']) + self.assertStartsWith(input_details[0]['name'], 'inputA') self.assertEqual(np.float32, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((0., 0.), input_details[0]['quantization']) - self.assertEqual('inputB', input_details[1]['name']) + self.assertStartsWith(input_details[1]['name'], 'inputB') self.assertEqual(np.float32, input_details[1]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) self.assertEqual((0., 0.), input_details[1]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) - self.assertEqual('add', output_details[0]['name']) + self.assertStartsWith(output_details[0]['name'], 'add') self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) @@ -1728,19 +1728,19 @@ class FromSavedModelTest(TestModels): input_details = interpreter.get_input_details() self.assertEqual(2, len(input_details)) - self.assertEqual('inputA', input_details[0]['name']) + self.assertStartsWith(input_details[0]['name'], 'inputA') 
self.assertEqual(np.float32, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((0., 0.), input_details[0]['quantization']) - self.assertEqual('inputB', input_details[1]['name']) + self.assertStartsWith(input_details[1]['name'], 'inputB') self.assertEqual(np.float32, input_details[1]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) self.assertEqual((0., 0.), input_details[1]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) - self.assertEqual('add', output_details[0]['name']) + self.assertStartsWith(output_details[0]['name'], 'add') self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 9512bdca70d..d04117c1a32 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -373,19 +373,22 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest): input_details = interpreter.get_input_details() self.assertLen(input_details, 2) - self.assertEqual('inputA', input_details[0]['name']) + self.assertStartsWith(input_details[0]['name'], 'inputA') self.assertEqual(np.float32, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((0., 0.), input_details[0]['quantization']) - self.assertEqual('inputB', input_details[1]['name']) + self.assertStartsWith( + input_details[1]['name'], + 'inputB', + ) self.assertEqual(np.float32, input_details[1]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) self.assertEqual((0., 0.), input_details[1]['quantization']) output_details = interpreter.get_output_details() self.assertLen(output_details, 1) - self.assertEqual('add', 
output_details[0]['name']) + self.assertStartsWith(output_details[0]['name'], 'add') self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization'])