Add verification code for input and output tensors in SavedModel importer

- Verify the given input and output names in the tf.entry_function in MLIR.
- Use input and output names with a colon in the saved model path.

PiperOrigin-RevId: 305810470
Change-Id: Id7f56ba216db2b60e6e1a11dbbcc0761a66b4635
This commit is contained in:
Jaesung Chung 2020-04-09 19:44:44 -07:00 committed by TensorFlower Gardener
parent c003ecc1d6
commit 49b16040a7
5 changed files with 105 additions and 14 deletions

View File

@ -16,9 +16,12 @@ limitations under the License.
#include <utility>
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
@ -41,6 +44,77 @@ limitations under the License.
namespace tensorflow {
Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
mlir::OwningModuleRef* module) {
mlir::FuncOp entry_function = nullptr;
for (auto func : module->get().getOps<mlir::FuncOp>()) {
if (auto tf_attrs =
func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
// TODO(jaesung): There could be multiple entry functions. Let's handle
// such cases if there are any needs for that.
if (entry_function != nullptr) {
return errors::InvalidArgument(
"There should be only one tf.entry_function");
}
entry_function = func;
}
}
if (entry_function == nullptr) {
return errors::InvalidArgument("no tf.entry_function found");
}
// Get the list of input Op names from the function attribute.
mlir::DictionaryAttr tf_attrs =
entry_function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
llvm::SmallVector<llvm::StringRef, 4> function_input_names;
function_input_names.reserve(model_flags.input_arrays().size());
auto input_attr = tf_attrs.get("inputs");
if (!input_attr) {
return errors::InvalidArgument("no inputs attribute found");
}
auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
input_names.split(function_input_names, ",");
if (function_input_names.size() != model_flags.input_arrays().size()) {
return errors::InvalidArgument(
"input array size mismatch: got ", function_input_names.size(),
", expected: ", model_flags.input_arrays().size());
}
llvm::StringSet<> function_input_names_set;
function_input_names_set.insert(function_input_names.begin(),
function_input_names.end());
for (const auto& input_array : model_flags.input_arrays()) {
if (function_input_names_set.count(input_array.name()) == 0) {
return errors::InvalidArgument("input array name (", input_array.name(),
") does not exist in the given graph");
}
}
// Get the list of output Op names from the function attribute.
llvm::SmallVector<llvm::StringRef, 4> function_output_names;
function_output_names.reserve(model_flags.output_arrays().size());
auto output_attr = tf_attrs.get("outputs");
if (!output_attr) {
return errors::InvalidArgument("no outputs attribute found");
}
auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
output_names.split(function_output_names, ",");
if (function_output_names.size() != model_flags.output_arrays().size()) {
return errors::InvalidArgument(
"output array size mismatch: got ", function_output_names.size(),
", expected: ", model_flags.output_arrays().size());
}
llvm::StringSet<> function_output_names_set;
function_output_names_set.insert(function_output_names.begin(),
function_output_names.end());
for (const auto& output_array : model_flags.output_arrays()) {
if (function_output_names_set.count(output_array) == 0) {
return errors::InvalidArgument("output array name (", output_array,
") does not exist in the given graph");
}
}
return Status::OK();
}
Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
string* result) {
@ -77,6 +151,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
model_flags.saved_model_version(), tags,
exported_names, &context));
if (!model_flags.input_arrays().empty() ||
!model_flags.output_arrays().empty()) {
TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
}
mlir::TFL::PassConfig pass_config(quant_specs);
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;

View File

@ -400,7 +400,10 @@ def build_toco_convert_protos(input_tensors,
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
input_array = model.input_arrays.add()
input_array.name = util.get_tensor_name(input_tensor)
if saved_model_dir:
input_array.name = input_tensor.name
else:
input_array.name = util.get_tensor_name(input_tensor)
input_array.data_type = util.convert_dtype_to_tflite_type(
input_tensor.dtype)
@ -423,7 +426,10 @@ def build_toco_convert_protos(input_tensors,
input_array.shape.dims.extend(dims)
for output_tensor in output_tensors:
model.output_arrays.append(util.get_tensor_name(output_tensor))
if saved_model_dir:
model.output_arrays.append(output_tensor.name)
else:
model.output_arrays.append(util.get_tensor_name(output_tensor))
model.allow_nonexistent_arrays = allow_nonexistent_arrays

View File

@ -382,6 +382,9 @@ class TFLiteConverterBase(object):
def _parse_saved_model_args(self):
"""Parses SavedModel arguments from the given Keras/RNN SavedModel."""
if not self.experimental_new_converter:
self._saved_model_dir = None
return
if self._saved_model_dir:
try:
saved_model_proto, _ = (

View File

@ -1634,19 +1634,19 @@ class FromSavedModelTest(TestModels):
input_details = interpreter.get_input_details()
self.assertEqual(2, len(input_details))
self.assertEqual('inputA', input_details[0]['name'])
self.assertStartsWith(input_details[0]['name'], 'inputA')
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((0., 0.), input_details[0]['quantization'])
self.assertEqual('inputB', input_details[1]['name'])
self.assertStartsWith(input_details[1]['name'], 'inputB')
self.assertEqual(np.float32, input_details[1]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
self.assertEqual((0., 0.), input_details[1]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
self.assertStartsWith(output_details[0]['name'], 'add')
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
@ -1696,19 +1696,19 @@ class FromSavedModelTest(TestModels):
input_details = interpreter.get_input_details()
self.assertEqual(2, len(input_details))
self.assertEqual('inputA', input_details[0]['name'])
self.assertStartsWith(input_details[0]['name'], 'inputA')
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((0., 0.), input_details[0]['quantization'])
self.assertEqual('inputB', input_details[1]['name'])
self.assertStartsWith(input_details[1]['name'], 'inputB')
self.assertEqual(np.float32, input_details[1]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
self.assertEqual((0., 0.), input_details[1]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
self.assertStartsWith(output_details[0]['name'], 'add')
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
@ -1728,19 +1728,19 @@ class FromSavedModelTest(TestModels):
input_details = interpreter.get_input_details()
self.assertEqual(2, len(input_details))
self.assertEqual('inputA', input_details[0]['name'])
self.assertStartsWith(input_details[0]['name'], 'inputA')
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((0., 0.), input_details[0]['quantization'])
self.assertEqual('inputB', input_details[1]['name'])
self.assertStartsWith(input_details[1]['name'], 'inputB')
self.assertEqual(np.float32, input_details[1]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
self.assertEqual((0., 0.), input_details[1]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
self.assertStartsWith(output_details[0]['name'], 'add')
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])

View File

@ -373,19 +373,22 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
input_details = interpreter.get_input_details()
self.assertLen(input_details, 2)
self.assertEqual('inputA', input_details[0]['name'])
self.assertStartsWith(input_details[0]['name'], 'inputA')
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((0., 0.), input_details[0]['quantization'])
self.assertEqual('inputB', input_details[1]['name'])
self.assertStartsWith(
input_details[1]['name'],
'inputB',
)
self.assertEqual(np.float32, input_details[1]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
self.assertEqual((0., 0.), input_details[1]['quantization'])
output_details = interpreter.get_output_details()
self.assertLen(output_details, 1)
self.assertEqual('add', output_details[0]['name'])
self.assertStartsWith(output_details[0]['name'], 'add')
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])