Add check for control flow v1 in converter. If found any control flow v1 in the
graph after importing then an error message will show and failure is signaled. PiperOrigin-RevId: 317938209 Change-Id: I58e14a25dad2f2337d8ad05de55aff757dfcd0b2
This commit is contained in:
parent
cc9d951afa
commit
b88bebf1ed
@ -26,6 +26,7 @@ filegroup(
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -0,0 +1,257 @@
|
||||
# RUN: not tf_tfl_translate -tf-upgrade-legacy=false -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=1,2:1 -tf-output-arrays=cond/Merge -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo -output-mlir %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
# CHECK: error: The graph has Control Flow V1 ops. TFLite converter doesn't support Control Flow V1 ops. Consider using Control Flow V2 ops instead.
|
||||
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
tensor_content: "\315\314\314=\315\314L>\232\231\231>\315\314\314>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Placeholder"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Placeholder_1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Switch"
|
||||
op: "Switch"
|
||||
input: "Placeholder_1"
|
||||
input: "Placeholder_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/switch_t"
|
||||
op: "Identity"
|
||||
input: "cond/Switch:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/switch_f"
|
||||
op: "Identity"
|
||||
input: "cond/Switch"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/pred_id"
|
||||
op: "Identity"
|
||||
input: "Placeholder_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/MatMul"
|
||||
op: "MatMul"
|
||||
input: "cond/MatMul/Switch:1"
|
||||
input: "cond/MatMul/Switch_1:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "transpose_a"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "transpose_b"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/MatMul/Switch"
|
||||
op: "Switch"
|
||||
input: "Placeholder"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Placeholder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/MatMul/Switch_1"
|
||||
op: "Switch"
|
||||
input: "Const"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Const"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Add"
|
||||
op: "Add"
|
||||
input: "cond/Add/Switch"
|
||||
input: "cond/Add/Switch_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Add/Switch"
|
||||
op: "Switch"
|
||||
input: "Placeholder"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Placeholder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Add/Switch_1"
|
||||
op: "Switch"
|
||||
input: "Const"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Const"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Merge"
|
||||
op: "Merge"
|
||||
input: "cond/Add"
|
||||
input: "cond/MatMul"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "init"
|
||||
op: "NoOp"
|
||||
}
|
||||
versions {
|
||||
producer: 134
|
||||
}
|
@ -172,7 +172,7 @@ int main(int argc, char **argv) {
|
||||
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
|
||||
debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays,
|
||||
/*prune_unused_nodes=*/true, &source_mgr, &context);
|
||||
/*prune_unused_nodes=*/true, upgrade_legacy, &source_mgr, &context);
|
||||
}
|
||||
|
||||
// If errors occur, the library call in the above already logged the error
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
@ -39,19 +41,47 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
using mlir::MLIRContext;
|
||||
using mlir::ModuleOp;
|
||||
using mlir::Operation;
|
||||
using mlir::OwningModuleRef;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
bool IsControlFlowV1Op(Operation* op) {
|
||||
return mlir::isa<mlir::tf_executor::SwitchOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::MergeOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::EnterOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::ExitOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::NextIterationSinkOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::NextIterationSourceOp>(op);
|
||||
}
|
||||
|
||||
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
|
||||
auto result = module.walk([&](Operation* op) {
|
||||
return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
|
||||
: mlir::WalkResult::advance();
|
||||
});
|
||||
if (result.wasInterrupted()) {
|
||||
module.emitError(
|
||||
"The graph has Control Flow V1 ops. TFLite converter doesn't support "
|
||||
"Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
|
||||
"https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
|
||||
"enable_control_flow_v2.");
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
const std::string& input_filename, bool input_mlir,
|
||||
bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
llvm::SourceMgr* source_mgr, MLIRContext* context) {
|
||||
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
|
||||
MLIRContext* context) {
|
||||
// Set up the input file.
|
||||
std::string error_message;
|
||||
auto file = mlir::openInputFile(input_filename, &error_message);
|
||||
@ -86,14 +116,14 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
||||
/*graph_as_function=*/false, enable_upgrade_legacy,
|
||||
/*enable_shape_inference=*/false, context);
|
||||
}
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
||||
/*graph_as_function=*/false, enable_upgrade_legacy,
|
||||
/*enable_shape_inference=*/false, context);
|
||||
}
|
||||
|
||||
@ -104,7 +134,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
mlir::PassManager* pass_manager) {
|
||||
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
|
||||
/*propagate=*/true);
|
||||
if (failed(pass_manager->run(module))) {
|
||||
|
||||
if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,8 @@ LoadFromGraphdefOrMlirSource(
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
|
||||
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Load Saved model (either v1 or v2) into MLIR.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
|
Loading…
x
Reference in New Issue
Block a user