Fix BatchMatMulV2 case in TFLite/Toco
PiperOrigin-RevId: 244412567
parent 7b6ac0e5d2
commit 844842a6b0
@@ -346,7 +346,6 @@ def generated_test_models_failing(conversion_mode):
     if conversion_mode == "toco-flex":
         return [
             "lstm",  # TODO(b/117510976): Restore when lstm flex conversion works.
-            "unroll_batch_matmul",  # TODO(b/123030774): Fails in 1.13 tests.
             "unidirectional_sequence_lstm",
             "unidirectional_sequence_rnn",
         ]
@@ -76,6 +76,16 @@ KNOWN_BUGS = {
     r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
     # Div will use floordiv.
     r"div.*int32": "72051395",
+
+    # TFLite/Toco does not support BatchMatMul(V2) broadcasting semantics yet.
+    # Simple broadcast.
+    r"unroll_batch_matmul.*shape=\[\(1,2,3\),\(3,5\).*": "130887526",
+    # Empty batch broadcast.
+    r"unroll_batch_matmul.*shape=\[\(2,5,3\),\(3,7\).*": "130887526",
+    # Single batch with non-empty batch broadcast.
+    r"unroll_batch_matmul.*shape=\[\(1,5,3\),\(4,3,7\).*": "130887526",
+    # Broadcast both operands.
+    r"unroll_batch_matmul.*shape=\[\(3,1,5,3\),\(1,4,3,7\).*": "130887526",
 }
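For reference, the broadcasting these KNOWN_BUGS entries exercise is NumPy-style broadcasting of the leading batch dimensions. A minimal sketch of the expected semantics using NumPy (illustrative only, not TFLite code):

import numpy as np

# "Broadcast both operands": batch dims (3, 1) and (1, 4) broadcast to (3, 4),
# matching the shape=[(3,1,5,3),(1,4,3,7)] case marked as a known bug above.
a = np.random.rand(3, 1, 5, 3)
b = np.random.rand(1, 4, 3, 7)
out = np.matmul(a, b)  # np.matmul broadcasts batch dims like BatchMatMulV2.
assert out.shape == (3, 4, 5, 7)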
@@ -4376,31 +4386,58 @@ def make_mirror_pad_tests(options):
 def make_unroll_batch_matmul_tests(options):
   """Make a set of tests to test unroll_batch_matmul."""
 
+  # The test cases below require broadcasting support (BatchMatMulV2 semantics),
+  # which isn't supported as of this change.
+  broadcast_shape_params = [
+      # Simple broadcast.
+      [(1, 2, 3), (3, 5), False, False],
+      # Empty batch broadcast.
+      [(2, 5, 3), (3, 7), False, False],
+      # Single batch with non-empty batch broadcast.
+      [(1, 5, 3), (4, 3, 7), False, False],
+      # Broadcast both operands.
+      [(3, 1, 5, 3), (1, 4, 3, 7), False, False],
+  ]
+
   test_parameters = [{
       "dtype": [tf.float32],
-      "shape": [[(2, 2, 3), (2, 3, 2), False, False],
-                [(2, 2, 3), (2, 3, 2), True, True],
-                [(2, 2, 3), (2, 2, 3), False, True],
-                [(2, 2, 3), (2, 2, 3), True, False],
-                [(4, 2, 2, 3), (4, 2, 3, 2), False, False],
-                [(4, 2, 2, 3), (4, 2, 3, 2), True, True],
-                [(4, 2, 2, 3), (4, 2, 2, 3), False, True],
-                [(4, 2, 2, 3), (4, 2, 2, 3), True, False]]
+      "shape": [
+          [(2, 2, 3), (2, 3, 2), False, False],
+          [(2, 2, 3), (2, 3, 2), True, True],
+          [(2, 2, 3), (2, 2, 3), False, True],
+          [(2, 2, 3), (2, 2, 3), True, False],
+          [(4, 2, 2, 3), (4, 2, 3, 2), False, False],
+          [(4, 2, 2, 3), (4, 2, 3, 2), True, True],
+          [(4, 2, 2, 3), (4, 2, 2, 3), False, True],
+          [(4, 2, 2, 3), (4, 2, 2, 3), True, False]
+      ] + broadcast_shape_params,
+      # TODO(b/130887442): Improve the forward compatibility tests for every
+      # op.
+      "forward_compatibility_test": [False, True],
   }]
 
   def build_graph(parameters):
     """Build the batch_matmul op testing graph."""
-    input_tensor1 = tf.placeholder(
-        dtype=parameters["dtype"], shape=parameters["shape"][0])
-    input_tensor2 = tf.placeholder(
-        dtype=parameters["dtype"], shape=parameters["shape"][1])
-    # Should be unrolled and replaced with fully_connected ops in the end.
-    out = tf.matmul(
-        input_tensor1,
-        input_tensor2,
-        transpose_a=parameters["shape"][2],
-        transpose_b=parameters["shape"][3])
-    return [input_tensor1, input_tensor2], [out]
+
+    def _build_graph():
+      input_tensor1 = tf.placeholder(
+          dtype=parameters["dtype"], shape=parameters["shape"][0])
+      input_tensor2 = tf.placeholder(
+          dtype=parameters["dtype"], shape=parameters["shape"][1])
+      # Should be unrolled and replaced with fully_connected ops in the end.
+      out = tf.matmul(
+          input_tensor1,
+          input_tensor2,
+          transpose_a=parameters["shape"][2],
+          transpose_b=parameters["shape"][3])
+      return [input_tensor1, input_tensor2], [out]
+
+    if parameters["forward_compatibility_test"]:
+      # This is hardcoded to the date after MatMulV2 is activated.
+      # TODO(b/130887442): Improve the forward compatibility tests for every
+      # op, and remove the hardcoded date.
+      with tf.compat.forward_compatibility_horizon(2019, 4, 26):
+        return _build_graph()
+    else:
+      return _build_graph()
 
   def build_inputs(parameters, sess, inputs, outputs):
     input_value1 = create_tensor_data(
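The tf.compat.forward_compatibility_horizon context manager used above makes date-gated behavior act as if the process were running just before the given future date, which is how the test reaches the MatMulV2 code path before it becomes the default. A minimal sketch of the API (TF 1.13-era, for illustration):

import tensorflow as tf

# A guard like forward_compatible(2019, 4, 25) is False until that date has
# passed; inside the horizon it reports True, so the graph builder emits the
# new (BatchMatMulV2) op instead of the old one.
with tf.compat.forward_compatibility_horizon(2019, 4, 26):
  assert tf.compat.forward_compatible(2019, 4, 25)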
@@ -4410,7 +4447,9 @@ def make_unroll_batch_matmul_tests(options):
     return [input_value1, input_value2], sess.run(
         outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
 
-  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
+  make_zip_of_tests(
+      options, test_parameters, build_graph, build_inputs,
+      expected_tf_failures=len(broadcast_shape_params))
 
 
 @register_make_test_function()
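For context on the "unrolled and replaced with fully_connected ops" comment above: unrolling a batch matmul slices each batch element out, runs a plain 2-D matmul per slice, and stacks the results. A rough NumPy equivalent of that computation (a sketch of the semantics, not Toco's actual graph transformation):

import numpy as np

def unrolled_batch_matmul(a, b):
  # One 2-D matmul per batch element, then stack; this is the computation
  # the unroll transformation expresses with individual matmul/FC ops.
  return np.stack([np.matmul(a[i], b[i]) for i in range(a.shape[0])])

a = np.random.rand(2, 2, 3)
b = np.random.rand(2, 3, 2)
assert np.allclose(unrolled_batch_matmul(a, b), np.matmul(a, b))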
@@ -2427,6 +2427,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
        ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
       {"AvgPool", ConvertAvgPoolOperator},
       {"BatchMatMul", ConvertBatchMatMulOperator},
+      {"BatchMatMulV2", ConvertBatchMatMulOperator},
       {"BatchNormWithGlobalNormalization",
        ConvertBatchNormWithGlobalNormalizationOperator},
       {"BatchToSpaceND", ConvertBatchToSpaceNDOperator},
@@ -58,6 +58,7 @@ bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name) {
       "AvgPool3D",
       "AvgPoolGrad",
       "BatchMatMul",
+      "BatchMatMulV2",
       "BatchNormWithGlobalNormalization",
       "BatchNormWithGlobalNormalizationGrad",
       "BatchToSpace",