Fix BatchMatMulV2 case in TFLite/Toco

PiperOrigin-RevId: 244412567
This commit is contained in:
Yu-Cheng Ling 2019-04-19 13:30:55 -07:00 committed by TensorFlower Gardener
parent 7b6ac0e5d2
commit 844842a6b0
4 changed files with 61 additions and 21 deletions

View File

@ -346,7 +346,6 @@ def generated_test_models_failing(conversion_mode):
if conversion_mode == "toco-flex":
return [
"lstm", # TODO(b/117510976): Restore when lstm flex conversion works.
"unroll_batch_matmul", # TODO(b/123030774): Fails in 1.13 tests.
"unidirectional_sequence_lstm",
"unidirectional_sequence_rnn",
]

View File

@ -76,6 +76,16 @@ KNOWN_BUGS = {
r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
# Div will use floordiv.
r"div.*int32": "72051395",
# TFLite/Toco does not support BatchMatMul(V2) broadcasting semantics yet.
# Simple broadcast.
r"unroll_batch_matmul.*shape=\[\(1,2,3\),\(3,5\).*": "130887526",
# Empty batch broadcast.
r"unroll_batch_matmul.*shape=\[\(2,5,3\),\(3,7\).*": "130887526",
# Single batch with non-empty batch broadcast.
r"unroll_batch_matmul.*shape=\[\(1,5,3\),\(4,3,7\).*": "130887526",
# Broadcast both operands
r"unroll_batch_matmul.*shape=\[\(3,1,5,3\),\(1,4,3,7\).*": "130887526",
}
@ -4376,31 +4386,58 @@ def make_mirror_pad_tests(options):
def make_unroll_batch_matmul_tests(options):
  """Make a set of tests to test unroll_batch_matmul."""
  # The test cases below require broadcasting support (BatchMatMulV2
  # semantics), which isn't supported as of this change.
  broadcast_shape_params = [
      # Simple broadcast.
      [(1, 2, 3), (3, 5), False, False],
      # Empty batch broadcast.
      [(2, 5, 3), (3, 7), False, False],
      # Single batch with non-empty batch broadcast.
      [(1, 5, 3), (4, 3, 7), False, False],
      # Broadcast both operands.
      [(3, 1, 5, 3), (1, 4, 3, 7), False, False],
  ]

  test_parameters = [{
      "dtype": [tf.float32],
      # Each entry is [lhs_shape, rhs_shape, transpose_a, transpose_b].
      "shape": [
          [(2, 2, 3), (2, 3, 2), False, False],
          [(2, 2, 3), (2, 3, 2), True, True],
          [(2, 2, 3), (2, 2, 3), False, True],
          [(2, 2, 3), (2, 2, 3), True, False],
          [(4, 2, 2, 3), (4, 2, 3, 2), False, False],
          [(4, 2, 2, 3), (4, 2, 3, 2), True, True],
          [(4, 2, 2, 3), (4, 2, 2, 3), False, True],
          [(4, 2, 2, 3), (4, 2, 2, 3), True, False]
      ] + broadcast_shape_params,
      # TODO(b/130887442): Improve the forward compatibility tests for every
      # op.
      "forward_compatibility_test": [False, True],
  }]

  def build_graph(parameters):
    """Build the batch_matmul op testing graph."""

    def _build_graph():
      # Shared by both forward-compatibility branches below.
      input_tensor1 = tf.placeholder(
          dtype=parameters["dtype"], shape=parameters["shape"][0])
      input_tensor2 = tf.placeholder(
          dtype=parameters["dtype"], shape=parameters["shape"][1])
      # Should be unrolled and replaced with fully_connected ops in the end.
      out = tf.matmul(
          input_tensor1,
          input_tensor2,
          transpose_a=parameters["shape"][2],
          transpose_b=parameters["shape"][3])
      return [input_tensor1, input_tensor2], [out]

    if parameters["forward_compatibility_test"]:
      # This is hardcoded to the date after MatMulV2 is activated.
      # TODO(b/130887442): Improve the forward compatibility tests for every
      # op, and remove the hardcoded date.
      with tf.compat.forward_compatibility_horizon(2019, 4, 26):
        return _build_graph()
    else:
      return _build_graph()

  def build_inputs(parameters, sess, inputs, outputs):
    # NOTE(review): the interior of this function was truncated in the
    # extracted diff; the create_tensor_data calls below are reconstructed to
    # mirror the placeholder shapes used in build_graph — verify against the
    # original file.
    input_value1 = create_tensor_data(
        parameters["dtype"], shape=parameters["shape"][0])
    input_value2 = create_tensor_data(
        parameters["dtype"], shape=parameters["shape"][1])
    return [input_value1, input_value2], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))

  # The broadcasting cases are expected to fail the TF-vs-TFLite comparison
  # until the converter supports BatchMatMulV2 broadcasting (b/130887526).
  make_zip_of_tests(
      options, test_parameters, build_graph, build_inputs,
      expected_tf_failures=len(broadcast_shape_params))
@register_make_test_function()

View File

@ -2427,6 +2427,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
{"AvgPool", ConvertAvgPoolOperator},
{"BatchMatMul", ConvertBatchMatMulOperator},
{"BatchMatMulV2", ConvertBatchMatMulOperator},
{"BatchNormWithGlobalNormalization",
ConvertBatchNormWithGlobalNormalizationOperator},
{"BatchToSpaceND", ConvertBatchToSpaceNDOperator},

View File

@ -58,6 +58,7 @@ bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name) {
"AvgPool3D",
"AvgPoolGrad",
"BatchMatMul",
"BatchMatMulV2",
"BatchNormWithGlobalNormalization",
"BatchNormWithGlobalNormalizationGrad",
"BatchToSpace",