Provide user friendly errors when upgrading legacy control flow ops is failed
PiperOrigin-RevId: 347731417 Change-Id: Id9ddaf308be4b9b13d9aa266f6dda9165fb4f4a6
This commit is contained in:
parent
442824b76e
commit
e1cb1f17ee
@ -190,8 +190,13 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
|
||||
restrict_functionalization_to_tpu_nodes
|
||||
? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
|
||||
: NodeFilter{};
|
||||
return FunctionalizeControlFlow(graph, flib_def, node_filter,
|
||||
/*include_functions=*/true);
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
FunctionalizeControlFlow(graph, flib_def, node_filter,
|
||||
/*include_functions=*/true),
|
||||
"Failed to functionalize Control Flow V1 ops. Consider using Control "
|
||||
"Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
|
||||
"compat/v1/enable_control_flow_v2.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Stateful helper class to import a TensorFlow model into an MLIR Module.
|
||||
|
@ -186,7 +186,10 @@ py_library(
|
||||
py_test(
|
||||
name = "lite_test",
|
||||
srcs = ["lite_test.py"],
|
||||
data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
|
||||
data = [
|
||||
"//tensorflow/lite/python/testdata:control_flow_v1.pbtxt",
|
||||
"@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
|
||||
],
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
srcs_version = "PY2AND3",
|
||||
@ -205,6 +208,9 @@ py_test(
|
||||
py_test(
|
||||
name = "lite_v2_test",
|
||||
srcs = ["lite_v2_test.py"],
|
||||
data = [
|
||||
"//tensorflow/lite/python/testdata/control_flow_v1_saved_model:saved_model.pb",
|
||||
],
|
||||
python_version = "PY3",
|
||||
shard_count = 12,
|
||||
srcs_version = "PY2AND3",
|
||||
|
@ -2740,5 +2740,24 @@ class DefaultConverterAttrsTest(LiteTest):
|
||||
self.assertIsNone(converter.conversion_summary_dir)
|
||||
|
||||
|
||||
class ControlFlowV1OpsTest(LiteTest):
|
||||
|
||||
def testConverterErrorOnControlFlowV1Ops(self):
|
||||
graph_def_file = resource_loader.get_path_to_datafile(
|
||||
'testdata/control_flow_v1.pbtxt')
|
||||
input_arrays = ['a', 'b', 'c', 'd']
|
||||
output_arrays = ['Merge']
|
||||
|
||||
converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
|
||||
input_arrays,
|
||||
output_arrays)
|
||||
with self.assertRaises(ConverterError) as error:
|
||||
converter.convert()
|
||||
self.assertIn(
|
||||
'Failed to functionalize Control Flow V1 ops. Consider using Control '
|
||||
'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
|
||||
'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -28,6 +28,7 @@ from six.moves import zip
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.lite.kernels.hashtable import pywrap_hashtable_ops as hashtable_ops_registerer
|
||||
from tensorflow.lite.python import convert
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python import lite_v2_test_util
|
||||
from tensorflow.lite.python.convert import mlir_quantize
|
||||
@ -38,6 +39,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import save_options
|
||||
from tensorflow.python.saved_model import saved_model
|
||||
@ -1263,6 +1265,18 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
tflite_model, [input_data['x'], input_data['b']])[0]
|
||||
self.assertAllClose(expected_value, actual_value)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testConverterErrorOnControlFlowV1Ops(self):
|
||||
filename = resource_loader.get_path_to_datafile(
|
||||
'testdata/control_flow_v1_saved_model')
|
||||
converter = lite.TFLiteConverterV2.from_saved_model(filename)
|
||||
with self.assertRaises(convert.ConverterError) as error:
|
||||
converter.convert()
|
||||
self.assertIn(
|
||||
'Failed to functionalize Control Flow V1 ops. Consider using Control '
|
||||
'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
|
||||
'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testStaticRnn(self):
|
||||
input_data = tf.constant(
|
||||
|
5
tensorflow/lite/python/testdata/BUILD
vendored
5
tensorflow/lite/python/testdata/BUILD
vendored
@ -12,7 +12,10 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0,
|
||||
)
|
||||
|
||||
exports_files(glob(["*.pb"]))
|
||||
exports_files(glob([
|
||||
"*.pb",
|
||||
"*.pbtxt",
|
||||
]))
|
||||
|
||||
tf_to_tflite(
|
||||
name = "permute_float",
|
||||
|
64
tensorflow/lite/python/testdata/control_flow_v1.pbtxt
vendored
Normal file
64
tensorflow/lite/python/testdata/control_flow_v1.pbtxt
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
node {
|
||||
name: "a"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "b"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "c"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "d"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Merge"
|
||||
op: "Merge"
|
||||
input: "a"
|
||||
input: "b"
|
||||
input: "c"
|
||||
input: "d"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 4
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
8
tensorflow/lite/python/testdata/control_flow_v1_saved_model/BUILD
vendored
Normal file
8
tensorflow/lite/python/testdata/control_flow_v1_saved_model/BUILD
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
licenses = ["notice"], # Apache 2.0,
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"saved_model.pb",
|
||||
])
|
BIN
tensorflow/lite/python/testdata/control_flow_v1_saved_model/saved_model.pb
vendored
Normal file
BIN
tensorflow/lite/python/testdata/control_flow_v1_saved_model/saved_model.pb
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user