Provide user friendly errors when upgrading legacy control flow ops is failed

PiperOrigin-RevId: 347731417
Change-Id: Id9ddaf308be4b9b13d9aa266f6dda9165fb4f4a6
This commit is contained in:
Jaesung Chung 2020-12-15 18:07:55 -08:00 committed by TensorFlower Gardener
parent 442824b76e
commit e1cb1f17ee
8 changed files with 123 additions and 4 deletions

View File

@ -190,8 +190,13 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
restrict_functionalization_to_tpu_nodes
? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
: NodeFilter{};
return FunctionalizeControlFlow(graph, flib_def, node_filter,
/*include_functions=*/true);
TF_RETURN_WITH_CONTEXT_IF_ERROR(
FunctionalizeControlFlow(graph, flib_def, node_filter,
/*include_functions=*/true),
"Failed to functionalize Control Flow V1 ops. Consider using Control "
"Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
"compat/v1/enable_control_flow_v2.");
return Status::OK();
}
// Stateful helper class to import a TensorFlow model into an MLIR Module.

View File

@ -186,7 +186,10 @@ py_library(
py_test(
name = "lite_test",
srcs = ["lite_test.py"],
data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
data = [
"//tensorflow/lite/python/testdata:control_flow_v1.pbtxt",
"@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
],
python_version = "PY3",
shard_count = 4,
srcs_version = "PY2AND3",
@ -205,6 +208,9 @@ py_test(
py_test(
name = "lite_v2_test",
srcs = ["lite_v2_test.py"],
data = [
"//tensorflow/lite/python/testdata/control_flow_v1_saved_model:saved_model.pb",
],
python_version = "PY3",
shard_count = 12,
srcs_version = "PY2AND3",

View File

@ -2740,5 +2740,24 @@ class DefaultConverterAttrsTest(LiteTest):
self.assertIsNone(converter.conversion_summary_dir)
class ControlFlowV1OpsTest(LiteTest):
def testConverterErrorOnControlFlowV1Ops(self):
graph_def_file = resource_loader.get_path_to_datafile(
'testdata/control_flow_v1.pbtxt')
input_arrays = ['a', 'b', 'c', 'd']
output_arrays = ['Merge']
converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
input_arrays,
output_arrays)
with self.assertRaises(ConverterError) as error:
converter.convert()
self.assertIn(
'Failed to functionalize Control Flow V1 ops. Consider using Control '
'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
if __name__ == '__main__':
test.main()

View File

@ -28,6 +28,7 @@ from six.moves import zip
import tensorflow as tf
from tensorflow.lite.kernels.hashtable import pywrap_hashtable_ops as hashtable_ops_registerer
from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_v2_test_util
from tensorflow.lite.python.convert import mlir_quantize
@ -38,6 +39,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import saved_model
@ -1263,6 +1265,18 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
tflite_model, [input_data['x'], input_data['b']])[0]
self.assertAllClose(expected_value, actual_value)
@test_util.run_v2_only
def testConverterErrorOnControlFlowV1Ops(self):
filename = resource_loader.get_path_to_datafile(
'testdata/control_flow_v1_saved_model')
converter = lite.TFLiteConverterV2.from_saved_model(filename)
with self.assertRaises(convert.ConverterError) as error:
converter.convert()
self.assertIn(
'Failed to functionalize Control Flow V1 ops. Consider using Control '
'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
@test_util.run_v2_only
def testStaticRnn(self):
input_data = tf.constant(

View File

@ -12,7 +12,10 @@ package(
licenses = ["notice"], # Apache 2.0,
)
exports_files(glob(["*.pb"]))
exports_files(glob([
"*.pb",
"*.pbtxt",
]))
tf_to_tflite(
name = "permute_float",

View File

@ -0,0 +1,64 @@
node {
name: "a"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node {
name: "b"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node {
name: "c"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node {
name: "d"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node {
name: "Merge"
op: "Merge"
input: "a"
input: "b"
input: "c"
input: "d"
attr {
key: "N"
value {
i: 4
}
}
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 27
}

View File

@ -0,0 +1,8 @@
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0,
)
exports_files([
"saved_model.pb",
])