Add a check for control flow v1 ops in the converter. If any control flow v1 op is found in the

graph after importing, an error message is shown and failure is signaled.

PiperOrigin-RevId: 317938209
Change-Id: I58e14a25dad2f2337d8ad05de55aff757dfcd0b2
This commit is contained in:
Karim Nosir 2020-06-23 14:19:01 -07:00 committed by TensorFlower Gardener
parent cc9d951afa
commit b88bebf1ed
5 changed files with 297 additions and 7 deletions

View File

@ -26,6 +26,7 @@ filegroup(
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -0,0 +1,257 @@
# RUN: not tf_tfl_translate -tf-upgrade-legacy=false -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=1,2:1 -tf-output-arrays=cond/Merge -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo -output-mlir %s -o - 2>&1 | FileCheck %s
# CHECK: error: The graph has Control Flow V1 ops. TFLite converter doesn't support Control Flow V1 ops. Consider using Control Flow V2 ops instead.
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 2
}
dim {
size: 2
}
}
tensor_content: "\315\314\314=\315\314L>\232\231\231>\315\314\314>"
}
}
}
}
node {
name: "Placeholder"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 2
}
}
}
}
}
node {
name: "Placeholder_1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "cond/Switch"
op: "Switch"
input: "Placeholder_1"
input: "Placeholder_1"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/switch_t"
op: "Identity"
input: "cond/Switch:1"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/switch_f"
op: "Identity"
input: "cond/Switch"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/pred_id"
op: "Identity"
input: "Placeholder_1"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/MatMul"
op: "MatMul"
input: "cond/MatMul/Switch:1"
input: "cond/MatMul/Switch_1:1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "transpose_a"
value {
b: false
}
}
attr {
key: "transpose_b"
value {
b: false
}
}
}
node {
name: "cond/MatMul/Switch"
op: "Switch"
input: "Placeholder"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Placeholder"
}
}
}
}
node {
name: "cond/MatMul/Switch_1"
op: "Switch"
input: "Const"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Const"
}
}
}
}
node {
name: "cond/Add"
op: "Add"
input: "cond/Add/Switch"
input: "cond/Add/Switch_1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "cond/Add/Switch"
op: "Switch"
input: "Placeholder"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Placeholder"
}
}
}
}
node {
name: "cond/Add/Switch_1"
op: "Switch"
input: "Const"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Const"
}
}
}
}
node {
name: "cond/Merge"
op: "Merge"
input: "cond/Add"
input: "cond/MatMul"
attr {
key: "N"
value {
i: 2
}
}
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "init"
op: "NoOp"
}
versions {
producer: 134
}

View File

@ -172,7 +172,7 @@ int main(int argc, char **argv) {
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays,
/*prune_unused_nodes=*/true, &source_mgr, &context);
/*prune_unused_nodes=*/true, upgrade_legacy, &source_mgr, &context);
}
// If errors occur, the library call in the above already logged the error

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
@ -39,19 +41,47 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningModuleRef;
using stream_executor::port::StatusOr;
// Returns true if `op` is one of the tf_executor Control Flow V1 ops
// (Switch, Merge, Enter, Exit, or a NextIteration source/sink). These are
// the ops the TFLite converter rejects; see IsValidGraph.
bool IsControlFlowV1Op(Operation* op) {
  return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
                   mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
                   mlir::tf_executor::NextIterationSinkOp,
                   mlir::tf_executor::NextIterationSourceOp>(op);
}
// Walks every operation in `module` and fails (with an error emitted on the
// module) if any TensorFlow Control Flow V1 op is present, since the TFLite
// converter only supports Control Flow V2. Returns success otherwise.
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
  const auto walk_result = module.walk([](Operation* op) {
    if (IsControlFlowV1Op(op)) return mlir::WalkResult::interrupt();
    return mlir::WalkResult::advance();
  });
  // No V1 op anywhere in the module: the graph is convertible.
  if (!walk_result.wasInterrupted()) return mlir::success();
  module.emitError(
      "The graph has Control Flow V1 ops. TFLite converter doesn't support "
      "Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
      "https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
      "enable_control_flow_v2.");
  return mlir::failure();
}
} // namespace
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
const std::string& input_filename, bool input_mlir,
bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
llvm::SourceMgr* source_mgr, MLIRContext* context) {
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
MLIRContext* context) {
// Set up the input file.
std::string error_message;
auto file = mlir::openInputFile(input_filename, &error_message);
@ -86,14 +116,14 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, /*control_output_arrays=*/"",
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
/*graph_as_function=*/false, enable_upgrade_legacy,
/*enable_shape_inference=*/false, context);
}
return tensorflow::GraphdefToMlirTranslateFunction(
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, /*control_output_arrays=*/"",
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
/*graph_as_function=*/false, enable_upgrade_legacy,
/*enable_shape_inference=*/false, context);
}
@ -104,7 +134,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::PassManager* pass_manager) {
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
/*propagate=*/true);
if (failed(pass_manager->run(module))) {
if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
return statusHandler.ConsumeStatus();
}

View File

@ -41,7 +41,8 @@ LoadFromGraphdefOrMlirSource(
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
mlir::MLIRContext* context);
// Load Saved model (either v1 or v2) into MLIR.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(