Add verification code for input and output tensors in SavedModel importer
- Verify the given input and output names against the tf.entry_function attribute in MLIR.
- Use input and output tensor names with a colon (e.g. "inputA:0") in the SavedModel conversion path.

PiperOrigin-RevId: 305810470
Change-Id: Id7f56ba216db2b60e6e1a11dbbcc0761a66b4635
parent c003ecc1d6
commit 49b16040a7
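For orientation, this is roughly how the affected conversion path is exercised. The SavedModel directory below is hypothetical, the array names mirror the test models in this diff, and the MLIR-based converter is enabled explicitly; this is a sketch, not part of the commit:

```python
import tensorflow as tf

converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
    "/tmp/saved_model",                 # hypothetical SavedModel directory
    input_arrays=["inputA", "inputB"],  # names taken from the tests below
    output_arrays=["add"])
converter.experimental_new_converter = True  # route through the MLIR importer
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
# With this change the reported names may keep the ":0" suffix (e.g.
# "inputA:0"), and the importer rejects requested input/output arrays that
# do not exist in the tf.entry_function attribute of the imported module.
print([d["name"] for d in interpreter.get_input_details()])
```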
tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc

@@ -16,9 +16,12 @@ limitations under the License.
 #include <utility>

+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
 #include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
@@ -41,6 +44,77 @@ limitations under the License.
 namespace tensorflow {

+Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
+                                         mlir::OwningModuleRef* module) {
+  mlir::FuncOp entry_function = nullptr;
+  for (auto func : module->get().getOps<mlir::FuncOp>()) {
+    if (auto tf_attrs =
+            func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
+      // TODO(jaesung): There could be multiple entry functions. Let's handle
+      // such cases if there are any needs for that.
+      if (entry_function != nullptr) {
+        return errors::InvalidArgument(
+            "There should be only one tf.entry_function");
+      }
+      entry_function = func;
+    }
+  }
+  if (entry_function == nullptr) {
+    return errors::InvalidArgument("no tf.entry_function found");
+  }
+
+  // Get the list of input Op names from the function attribute.
+  mlir::DictionaryAttr tf_attrs =
+      entry_function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
+  llvm::SmallVector<llvm::StringRef, 4> function_input_names;
+  function_input_names.reserve(model_flags.input_arrays().size());
+  auto input_attr = tf_attrs.get("inputs");
+  if (!input_attr) {
+    return errors::InvalidArgument("no inputs attribute found");
+  }
+  auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
+  input_names.split(function_input_names, ",");
+  if (function_input_names.size() != model_flags.input_arrays().size()) {
+    return errors::InvalidArgument(
+        "input array size mismatch: got ", function_input_names.size(),
+        ", expected: ", model_flags.input_arrays().size());
+  }
+  llvm::StringSet<> function_input_names_set;
+  function_input_names_set.insert(function_input_names.begin(),
+                                  function_input_names.end());
+  for (const auto& input_array : model_flags.input_arrays()) {
+    if (function_input_names_set.count(input_array.name()) == 0) {
+      return errors::InvalidArgument("input array name (", input_array.name(),
+                                     ") does not exist in the given graph");
+    }
+  }
+
+  // Get the list of output Op names from the function attribute.
+  llvm::SmallVector<llvm::StringRef, 4> function_output_names;
+  function_output_names.reserve(model_flags.output_arrays().size());
+  auto output_attr = tf_attrs.get("outputs");
+  if (!output_attr) {
+    return errors::InvalidArgument("no outputs attribute found");
+  }
+  auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
+  output_names.split(function_output_names, ",");
+  if (function_output_names.size() != model_flags.output_arrays().size()) {
+    return errors::InvalidArgument(
+        "output array size mismatch: got ", function_output_names.size(),
+        ", expected: ", model_flags.output_arrays().size());
+  }
+  llvm::StringSet<> function_output_names_set;
+  function_output_names_set.insert(function_output_names.begin(),
+                                   function_output_names.end());
+  for (const auto& output_array : model_flags.output_arrays()) {
+    if (function_output_names_set.count(output_array) == 0) {
+      return errors::InvalidArgument("output array name (", output_array,
+                                     ") does not exist in the given graph");
+    }
+  }
+  return Status::OK();
+}
+
 Status ConvertSavedModelToTFLiteFlatBuffer(
     const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
     string* result) {
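In plain terms, the new check splits the comma-separated `inputs`/`outputs` strings stored in the `tf.entry_function` attribute and requires every requested array to appear there, with matching counts. A minimal Python sketch of the same matching rules, using made-up attribute values (this is an illustration, not the converter's API):

```python
def verify_io_arrays(entry_attrs, input_arrays, output_arrays):
  """Mirrors HandleInputOutputArraysWithModule on plain dict/str inputs."""
  function_inputs = entry_attrs["inputs"].split(",")
  function_outputs = entry_attrs["outputs"].split(",")
  if len(function_inputs) != len(input_arrays):
    raise ValueError("input array size mismatch: got %d, expected: %d" %
                     (len(function_inputs), len(input_arrays)))
  for name in input_arrays:
    if name not in function_inputs:
      raise ValueError(
          "input array name (%s) does not exist in the given graph" % name)
  if len(function_outputs) != len(output_arrays):
    raise ValueError("output array size mismatch: got %d, expected: %d" %
                     (len(function_outputs), len(output_arrays)))
  for name in output_arrays:
    if name not in function_outputs:
      raise ValueError(
          "output array name (%s) does not exist in the given graph" % name)


# Names must match exactly, including the ":0" suffix (hypothetical values).
verify_io_arrays({"inputs": "inputA:0,inputB:0", "outputs": "add:0"},
                 ["inputA:0", "inputB:0"], ["add:0"])
```

The actual implementation works on `llvm::StringRef`/`llvm::StringSet` and the `model_flags` protos, but the matching behavior is the same.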
@@ -77,6 +151,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
       model_flags.saved_model_version(), tags,
       exported_names, &context));

+  if (!model_flags.input_arrays().empty() ||
+      !model_flags.output_arrays().empty()) {
+    TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
+  }
+
   mlir::TFL::PassConfig pass_config(quant_specs);
   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
   pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
tensorflow/lite/python/convert.py

@@ -400,7 +400,10 @@ def build_toco_convert_protos(input_tensors,
   model.change_concat_input_ranges = change_concat_input_ranges
   for idx, input_tensor in enumerate(input_tensors):
     input_array = model.input_arrays.add()
-    input_array.name = util.get_tensor_name(input_tensor)
+    if saved_model_dir:
+      input_array.name = input_tensor.name
+    else:
+      input_array.name = util.get_tensor_name(input_tensor)
     input_array.data_type = util.convert_dtype_to_tflite_type(
         input_tensor.dtype)
@@ -423,7 +426,10 @@ def build_toco_convert_protos(input_tensors,
     input_array.shape.dims.extend(dims)

   for output_tensor in output_tensors:
-    model.output_arrays.append(util.get_tensor_name(output_tensor))
+    if saved_model_dir:
+      model.output_arrays.append(output_tensor.name)
+    else:
+      model.output_arrays.append(util.get_tensor_name(output_tensor))

   model.allow_nonexistent_arrays = allow_nonexistent_arrays
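The change above matters because `util.get_tensor_name` strips a trailing `:0` from the tensor name, while `Tensor.name` keeps it; the SavedModel path now forwards the full name. A small illustration, where the graph and tensor name are made up and `util` is the internal `tensorflow.lite.python.util` module already used in this file:

```python
import tensorflow as tf
from tensorflow.lite.python import util  # internal helper referenced in the diff

with tf.Graph().as_default():
  x = tf.constant([1.0], name="inputA")
  print(x.name)                   # "inputA:0" -- used when saved_model_dir is set
  print(util.get_tensor_name(x))  # "inputA"   -- previous behavior (op name only)
```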
tensorflow/lite/python/lite.py

@@ -382,6 +382,9 @@ class TFLiteConverterBase(object):

   def _parse_saved_model_args(self):
     """Parses SavedModel arguments from the given Keras/RNN SavedModel."""
+    if not self.experimental_new_converter:
+      self._saved_model_dir = None
+      return
     if self._saved_model_dir:
       try:
         saved_model_proto, _ = (
tensorflow/lite/python/lite_test.py

@@ -1634,19 +1634,19 @@ class FromSavedModelTest(TestModels):

     input_details = interpreter.get_input_details()
     self.assertEqual(2, len(input_details))
-    self.assertEqual('inputA', input_details[0]['name'])
+    self.assertStartsWith(input_details[0]['name'], 'inputA')
     self.assertEqual(np.float32, input_details[0]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
     self.assertEqual((0., 0.), input_details[0]['quantization'])

-    self.assertEqual('inputB', input_details[1]['name'])
+    self.assertStartsWith(input_details[1]['name'], 'inputB')
     self.assertEqual(np.float32, input_details[1]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
     self.assertEqual((0., 0.), input_details[1]['quantization'])

     output_details = interpreter.get_output_details()
     self.assertEqual(1, len(output_details))
-    self.assertEqual('add', output_details[0]['name'])
+    self.assertStartsWith(output_details[0]['name'], 'add')
     self.assertEqual(np.float32, output_details[0]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
     self.assertEqual((0., 0.), output_details[0]['quantization'])
@@ -1696,19 +1696,19 @@ class FromSavedModelTest(TestModels):

     input_details = interpreter.get_input_details()
     self.assertEqual(2, len(input_details))
-    self.assertEqual('inputA', input_details[0]['name'])
+    self.assertStartsWith(input_details[0]['name'], 'inputA')
     self.assertEqual(np.float32, input_details[0]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
     self.assertEqual((0., 0.), input_details[0]['quantization'])

-    self.assertEqual('inputB', input_details[1]['name'])
+    self.assertStartsWith(input_details[1]['name'], 'inputB')
     self.assertEqual(np.float32, input_details[1]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
     self.assertEqual((0., 0.), input_details[1]['quantization'])

     output_details = interpreter.get_output_details()
     self.assertEqual(1, len(output_details))
-    self.assertEqual('add', output_details[0]['name'])
+    self.assertStartsWith(output_details[0]['name'], 'add')
     self.assertEqual(np.float32, output_details[0]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
     self.assertEqual((0., 0.), output_details[0]['quantization'])
@@ -1728,19 +1728,19 @@ class FromSavedModelTest(TestModels):

     input_details = interpreter.get_input_details()
     self.assertEqual(2, len(input_details))
-    self.assertEqual('inputA', input_details[0]['name'])
+    self.assertStartsWith(input_details[0]['name'], 'inputA')
     self.assertEqual(np.float32, input_details[0]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
     self.assertEqual((0., 0.), input_details[0]['quantization'])

-    self.assertEqual('inputB', input_details[1]['name'])
+    self.assertStartsWith(input_details[1]['name'], 'inputB')
     self.assertEqual(np.float32, input_details[1]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
     self.assertEqual((0., 0.), input_details[1]['quantization'])

     output_details = interpreter.get_output_details()
     self.assertEqual(1, len(output_details))
-    self.assertEqual('add', output_details[0]['name'])
+    self.assertStartsWith(output_details[0]['name'], 'add')
     self.assertEqual(np.float32, output_details[0]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
     self.assertEqual((0., 0.), output_details[0]['quantization'])
tensorflow/lite/python/lite_v2_test.py

@@ -373,19 +373,22 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):

     input_details = interpreter.get_input_details()
     self.assertLen(input_details, 2)
-    self.assertEqual('inputA', input_details[0]['name'])
+    self.assertStartsWith(input_details[0]['name'], 'inputA')
     self.assertEqual(np.float32, input_details[0]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
     self.assertEqual((0., 0.), input_details[0]['quantization'])

-    self.assertEqual('inputB', input_details[1]['name'])
+    self.assertStartsWith(
+        input_details[1]['name'],
+        'inputB',
+    )
     self.assertEqual(np.float32, input_details[1]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
     self.assertEqual((0., 0.), input_details[1]['quantization'])

     output_details = interpreter.get_output_details()
     self.assertLen(output_details, 1)
-    self.assertEqual('add', output_details[0]['name'])
+    self.assertStartsWith(output_details[0]['name'], 'add')
     self.assertEqual(np.float32, output_details[0]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
     self.assertEqual((0., 0.), output_details[0]['quantization'])
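The TF2 test above reflects the same behavior through `tf.lite.TFLiteConverter.from_saved_model`. A rough usage sketch (hypothetical SavedModel directory; whether `experimental_new_converter` needs to be set explicitly depends on the TensorFlow version):

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")
converter.experimental_new_converter = True  # MLIR-based SavedModel import
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
for detail in interpreter.get_input_details():
  # Names may now carry the ":0" suffix (e.g. "inputA:0"), which is why the
  # updated tests only assert on the name prefix.
  print(detail["name"])
```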