Add verification code for input and output tensors in SavedModel importer
- Verify the given input and output names in the tf.entry_function in MLIR. - Use input and output names with a colon in the saved model path. PiperOrigin-RevId: 305810470 Change-Id: Id7f56ba216db2b60e6e1a11dbbcc0761a66b4635
This commit is contained in:
parent
c003ecc1d6
commit
49b16040a7
@ -16,9 +16,12 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringSet.h"
|
||||||
#include "llvm/Support/ToolOutputFile.h"
|
#include "llvm/Support/ToolOutputFile.h"
|
||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||||
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
|
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
|
||||||
@ -41,6 +44,77 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
|
||||||
|
mlir::OwningModuleRef* module) {
|
||||||
|
mlir::FuncOp entry_function = nullptr;
|
||||||
|
for (auto func : module->get().getOps<mlir::FuncOp>()) {
|
||||||
|
if (auto tf_attrs =
|
||||||
|
func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
|
||||||
|
// TODO(jaesung): There could be multiple entry functions. Let's handle
|
||||||
|
// such cases if there are any needs for that.
|
||||||
|
if (entry_function != nullptr) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"There should be only one tf.entry_function");
|
||||||
|
}
|
||||||
|
entry_function = func;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (entry_function == nullptr) {
|
||||||
|
return errors::InvalidArgument("no tf.entry_function found");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the list of input Op names from the function attribute.
|
||||||
|
mlir::DictionaryAttr tf_attrs =
|
||||||
|
entry_function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
|
||||||
|
llvm::SmallVector<llvm::StringRef, 4> function_input_names;
|
||||||
|
function_input_names.reserve(model_flags.input_arrays().size());
|
||||||
|
auto input_attr = tf_attrs.get("inputs");
|
||||||
|
if (!input_attr) {
|
||||||
|
return errors::InvalidArgument("no inputs attribute found");
|
||||||
|
}
|
||||||
|
auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
|
||||||
|
input_names.split(function_input_names, ",");
|
||||||
|
if (function_input_names.size() != model_flags.input_arrays().size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"input array size mismatch: got ", function_input_names.size(),
|
||||||
|
", expected: ", model_flags.input_arrays().size());
|
||||||
|
}
|
||||||
|
llvm::StringSet<> function_input_names_set;
|
||||||
|
function_input_names_set.insert(function_input_names.begin(),
|
||||||
|
function_input_names.end());
|
||||||
|
for (const auto& input_array : model_flags.input_arrays()) {
|
||||||
|
if (function_input_names_set.count(input_array.name()) == 0) {
|
||||||
|
return errors::InvalidArgument("input array name (", input_array.name(),
|
||||||
|
") does not exist in the given graph");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the list of output Op names from the function attribute.
|
||||||
|
llvm::SmallVector<llvm::StringRef, 4> function_output_names;
|
||||||
|
function_output_names.reserve(model_flags.output_arrays().size());
|
||||||
|
auto output_attr = tf_attrs.get("outputs");
|
||||||
|
if (!output_attr) {
|
||||||
|
return errors::InvalidArgument("no outputs attribute found");
|
||||||
|
}
|
||||||
|
auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
|
||||||
|
output_names.split(function_output_names, ",");
|
||||||
|
if (function_output_names.size() != model_flags.output_arrays().size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"output array size mismatch: got ", function_output_names.size(),
|
||||||
|
", expected: ", model_flags.output_arrays().size());
|
||||||
|
}
|
||||||
|
llvm::StringSet<> function_output_names_set;
|
||||||
|
function_output_names_set.insert(function_output_names.begin(),
|
||||||
|
function_output_names.end());
|
||||||
|
for (const auto& output_array : model_flags.output_arrays()) {
|
||||||
|
if (function_output_names_set.count(output_array) == 0) {
|
||||||
|
return errors::InvalidArgument("output array name (", output_array,
|
||||||
|
") does not exist in the given graph");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||||
string* result) {
|
string* result) {
|
||||||
@ -77,6 +151,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
|||||||
model_flags.saved_model_version(), tags,
|
model_flags.saved_model_version(), tags,
|
||||||
exported_names, &context));
|
exported_names, &context));
|
||||||
|
|
||||||
|
if (!model_flags.input_arrays().empty() ||
|
||||||
|
!model_flags.output_arrays().empty()) {
|
||||||
|
TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
|
||||||
|
}
|
||||||
|
|
||||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||||
|
|||||||
@ -400,7 +400,10 @@ def build_toco_convert_protos(input_tensors,
|
|||||||
model.change_concat_input_ranges = change_concat_input_ranges
|
model.change_concat_input_ranges = change_concat_input_ranges
|
||||||
for idx, input_tensor in enumerate(input_tensors):
|
for idx, input_tensor in enumerate(input_tensors):
|
||||||
input_array = model.input_arrays.add()
|
input_array = model.input_arrays.add()
|
||||||
input_array.name = util.get_tensor_name(input_tensor)
|
if saved_model_dir:
|
||||||
|
input_array.name = input_tensor.name
|
||||||
|
else:
|
||||||
|
input_array.name = util.get_tensor_name(input_tensor)
|
||||||
input_array.data_type = util.convert_dtype_to_tflite_type(
|
input_array.data_type = util.convert_dtype_to_tflite_type(
|
||||||
input_tensor.dtype)
|
input_tensor.dtype)
|
||||||
|
|
||||||
@ -423,7 +426,10 @@ def build_toco_convert_protos(input_tensors,
|
|||||||
input_array.shape.dims.extend(dims)
|
input_array.shape.dims.extend(dims)
|
||||||
|
|
||||||
for output_tensor in output_tensors:
|
for output_tensor in output_tensors:
|
||||||
model.output_arrays.append(util.get_tensor_name(output_tensor))
|
if saved_model_dir:
|
||||||
|
model.output_arrays.append(output_tensor.name)
|
||||||
|
else:
|
||||||
|
model.output_arrays.append(util.get_tensor_name(output_tensor))
|
||||||
|
|
||||||
model.allow_nonexistent_arrays = allow_nonexistent_arrays
|
model.allow_nonexistent_arrays = allow_nonexistent_arrays
|
||||||
|
|
||||||
|
|||||||
@ -382,6 +382,9 @@ class TFLiteConverterBase(object):
|
|||||||
|
|
||||||
def _parse_saved_model_args(self):
|
def _parse_saved_model_args(self):
|
||||||
"""Parses SavedModel arguments from the given Keras/RNN SavedModel."""
|
"""Parses SavedModel arguments from the given Keras/RNN SavedModel."""
|
||||||
|
if not self.experimental_new_converter:
|
||||||
|
self._saved_model_dir = None
|
||||||
|
return
|
||||||
if self._saved_model_dir:
|
if self._saved_model_dir:
|
||||||
try:
|
try:
|
||||||
saved_model_proto, _ = (
|
saved_model_proto, _ = (
|
||||||
|
|||||||
@ -1634,19 +1634,19 @@ class FromSavedModelTest(TestModels):
|
|||||||
|
|
||||||
input_details = interpreter.get_input_details()
|
input_details = interpreter.get_input_details()
|
||||||
self.assertEqual(2, len(input_details))
|
self.assertEqual(2, len(input_details))
|
||||||
self.assertEqual('inputA', input_details[0]['name'])
|
self.assertStartsWith(input_details[0]['name'], 'inputA')
|
||||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
||||||
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||||
|
|
||||||
self.assertEqual('inputB', input_details[1]['name'])
|
self.assertStartsWith(input_details[1]['name'], 'inputB')
|
||||||
self.assertEqual(np.float32, input_details[1]['dtype'])
|
self.assertEqual(np.float32, input_details[1]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
|
||||||
self.assertEqual((0., 0.), input_details[1]['quantization'])
|
self.assertEqual((0., 0.), input_details[1]['quantization'])
|
||||||
|
|
||||||
output_details = interpreter.get_output_details()
|
output_details = interpreter.get_output_details()
|
||||||
self.assertEqual(1, len(output_details))
|
self.assertEqual(1, len(output_details))
|
||||||
self.assertEqual('add', output_details[0]['name'])
|
self.assertStartsWith(output_details[0]['name'], 'add')
|
||||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
||||||
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||||
@ -1696,19 +1696,19 @@ class FromSavedModelTest(TestModels):
|
|||||||
|
|
||||||
input_details = interpreter.get_input_details()
|
input_details = interpreter.get_input_details()
|
||||||
self.assertEqual(2, len(input_details))
|
self.assertEqual(2, len(input_details))
|
||||||
self.assertEqual('inputA', input_details[0]['name'])
|
self.assertStartsWith(input_details[0]['name'], 'inputA')
|
||||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
||||||
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||||
|
|
||||||
self.assertEqual('inputB', input_details[1]['name'])
|
self.assertStartsWith(input_details[1]['name'], 'inputB')
|
||||||
self.assertEqual(np.float32, input_details[1]['dtype'])
|
self.assertEqual(np.float32, input_details[1]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
|
||||||
self.assertEqual((0., 0.), input_details[1]['quantization'])
|
self.assertEqual((0., 0.), input_details[1]['quantization'])
|
||||||
|
|
||||||
output_details = interpreter.get_output_details()
|
output_details = interpreter.get_output_details()
|
||||||
self.assertEqual(1, len(output_details))
|
self.assertEqual(1, len(output_details))
|
||||||
self.assertEqual('add', output_details[0]['name'])
|
self.assertStartsWith(output_details[0]['name'], 'add')
|
||||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
||||||
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||||
@ -1728,19 +1728,19 @@ class FromSavedModelTest(TestModels):
|
|||||||
|
|
||||||
input_details = interpreter.get_input_details()
|
input_details = interpreter.get_input_details()
|
||||||
self.assertEqual(2, len(input_details))
|
self.assertEqual(2, len(input_details))
|
||||||
self.assertEqual('inputA', input_details[0]['name'])
|
self.assertStartsWith(input_details[0]['name'], 'inputA')
|
||||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
||||||
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||||
|
|
||||||
self.assertEqual('inputB', input_details[1]['name'])
|
self.assertStartsWith(input_details[1]['name'], 'inputB')
|
||||||
self.assertEqual(np.float32, input_details[1]['dtype'])
|
self.assertEqual(np.float32, input_details[1]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
|
||||||
self.assertEqual((0., 0.), input_details[1]['quantization'])
|
self.assertEqual((0., 0.), input_details[1]['quantization'])
|
||||||
|
|
||||||
output_details = interpreter.get_output_details()
|
output_details = interpreter.get_output_details()
|
||||||
self.assertEqual(1, len(output_details))
|
self.assertEqual(1, len(output_details))
|
||||||
self.assertEqual('add', output_details[0]['name'])
|
self.assertStartsWith(output_details[0]['name'], 'add')
|
||||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
||||||
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||||
|
|||||||
@ -373,19 +373,22 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
|||||||
|
|
||||||
input_details = interpreter.get_input_details()
|
input_details = interpreter.get_input_details()
|
||||||
self.assertLen(input_details, 2)
|
self.assertLen(input_details, 2)
|
||||||
self.assertEqual('inputA', input_details[0]['name'])
|
self.assertStartsWith(input_details[0]['name'], 'inputA')
|
||||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
||||||
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||||
|
|
||||||
self.assertEqual('inputB', input_details[1]['name'])
|
self.assertStartsWith(
|
||||||
|
input_details[1]['name'],
|
||||||
|
'inputB',
|
||||||
|
)
|
||||||
self.assertEqual(np.float32, input_details[1]['dtype'])
|
self.assertEqual(np.float32, input_details[1]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
|
||||||
self.assertEqual((0., 0.), input_details[1]['quantization'])
|
self.assertEqual((0., 0.), input_details[1]['quantization'])
|
||||||
|
|
||||||
output_details = interpreter.get_output_details()
|
output_details = interpreter.get_output_details()
|
||||||
self.assertLen(output_details, 1)
|
self.assertLen(output_details, 1)
|
||||||
self.assertEqual('add', output_details[0]['name'])
|
self.assertStartsWith(output_details[0]['name'], 'add')
|
||||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
||||||
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user