Fix unroll_batch_matmul toco transformation.
PiperOrigin-RevId: 222176219
commit be27c92dde
parent 3affd7655e
@@ -303,6 +303,7 @@ def generated_test_models():
         "transpose",
         "transpose_conv",
         "unpack",
+        "unroll_batch_matmul",
         "where",
         "zeros_like",
     ]
@@ -117,7 +117,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Note that quantized inference requires that all tensors have their
   // parameters set. This is usually done during quantized training.
   TfLiteType data_type = input->type;
-  if (data_type != kTfLiteFloat32) {
+  if (data_type != kTfLiteFloat32 && data_type != kTfLiteInt32) {
     double real_multiplier = 0.0;
     TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
         context, input, filter, bias, output, &real_multiplier));
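For context, the guarded block computes the requantization multiplier used by quantized kernels. The change lets plain int32 tensors, which carry no quantization parameters, skip that path. A minimal sketch of the multiplier GetQuantizedConvolutionMultipler is commonly described as producing, assuming per-tensor scales (not the actual TFLite C++ implementation):

```python
def real_multiplier(input_scale, filter_scale, output_scale):
    # Combined rescaling factor applied when requantizing the accumulator.
    # Float32 (and, after this change, int32) inputs never reach this
    # computation because they carry no quantization scales to combine.
    return (input_scale * filter_scale) / output_scale

print(real_multiplier(0.5, 0.25, 1.0))  # 0.125
```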
@@ -3470,6 +3470,32 @@ def make_logical_xor_tests(zip_path):
   return _make_logical_tests(tf.logical_xor)(zip_path)


+def make_unroll_batch_matmul_tests(zip_path):
+  """Make a set of tests to test unroll_batch_matmul."""
+
+  test_parameters = [{"dtype": [tf.float32], "shape": [[(2, 2, 3), (2, 3, 2)]]}]
+
+  def build_graph(parameters):
+    """Build the batch_matmul op testing graph."""
+    input_tensor1 = tf.placeholder(
+        dtype=parameters["dtype"], shape=parameters["shape"][0])
+    input_tensor2 = tf.placeholder(
+        dtype=parameters["dtype"], shape=parameters["shape"][1])
+    # Should be unrolled and replaced with fully_connected ops in the end.
+    out = tf.matmul(input_tensor1, input_tensor2)
+    return [input_tensor1, input_tensor2], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    input_value1 = create_tensor_data(
+        parameters["dtype"], shape=parameters["shape"][0])
+    input_value2 = create_tensor_data(
+        parameters["dtype"], shape=parameters["shape"][1])
+    return [input_value1, input_value2], sess.run(
+        outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
 # Toco binary path provided by the generate rule.
 bin_path = None

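What the new test exercises: tf.matmul on the 3-D shapes (2, 2, 3) x (2, 3, 2) is a batched matmul, i.e. one independent 2-D matmul per batch, which is what the unroll_batch_matmul transformation rewrites into slice + fully_connected ops. A reference sketch of that expected behaviour in plain NumPy (not the toco transformation itself):

```python
import numpy as np

a = np.random.rand(2, 2, 3).astype(np.float32)
b = np.random.rand(2, 3, 2).astype(np.float32)

# One independent 2-D matmul per batch, then stack the per-batch results.
unrolled = np.stack([a[i] @ b[i] for i in range(2)])

assert unrolled.shape == (2, 2, 2)
assert np.allclose(unrolled, np.matmul(a, b))
```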
@@ -117,7 +117,8 @@ namespace toco {
     auto* slice_b_op = new SliceOperator;
     slice_b_op->inputs = {
         batch_op->inputs[1],
-        CreateInt32Array(model, batch_name + "/slice_b/slice/begin", {0, 0, 0}),
+        CreateInt32Array(model, batch_name + "/slice_b/slice/begin",
+                         {batch, 0, 0}),
         CreateInt32Array(
             model, batch_name + "/slice_b/slice/size",
             {1, input_array_b.shape().dims(1), input_array_b.shape().dims(2)}),
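This one-line change is the substance of the fix: the begin index for the slice of the second input must advance with the batch counter, otherwise every unrolled fully_connected op reads batch 0 of input b. A NumPy sketch of the slice the fixed code builds on iteration `batch` (plain NumPy, not the toco C++ code):

```python
import numpy as np

def slice_b(b, batch):
    # begin = {batch, 0, 0}, size = {1, dims(1), dims(2)}, mirroring the fix;
    # the old begin of {0, 0, 0} always returned the first batch.
    begin = [batch, 0, 0]
    size = [1, b.shape[1], b.shape[2]]
    return b[begin[0]:begin[0] + size[0],
             begin[1]:begin[1] + size[1],
             begin[2]:begin[2] + size[2]]

b = np.arange(2 * 3 * 2, dtype=np.float32).reshape(2, 3, 2)
print(slice_b(b, 1).shape)  # (1, 3, 2): the second batch of b, not batch 0
```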