diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index 2860de6ad26..c52a3e5dc0b 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -76,16 +76,6 @@ KNOWN_BUGS = { r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", # Div will use floordiv. r"div.*int32": "72051395", - - # TFLite/Toco does not support BatchMatMul(V2) broadcasting semantic yet. - # Simple broadcast. - r"unroll_batch_matmul.*shape=\[\(1,2,3\),\(3,5\).*": "130887526", - # Empty batch broadcast. - r"unroll_batch_matmul.*shape=\[\(2,5,3\),\(3,7\).*": "130887526", - # Single batch with non-empty batch broadcast. - r"unroll_batch_matmul.*shape=\[\(1,5,3\),\(4,3,7\).*": "130887526", - # Broadcast both operands - r"unroll_batch_matmul.*shape=\[\(3,1,5,3\),\(1,4,3,7\).*": "130887526", } diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 8481b0b754c..e24c014acea 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -283,9 +283,11 @@ cc_library( ":runtime", ":toco_port", ":tooling_util", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/lite/kernels/internal:strided_slice_logic", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc index 7492f3e116c..50087b17267 100644 --- a/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc +++ b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -18,147 +18,75 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/matmul_bcast.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" namespace toco { - namespace { -void UnrollBatchMatMul3D( - const string& input_lhs, const string& input_rhs, - const BatchMatMulOperator* batch_op, const std::vector batch, - Model* model, std::vector>::iterator* tail_it, - std::vector* pack_inputs) { - const std::string batch_name = - absl::StrCat(batch_op->outputs[0], "_b", absl::StrJoin(batch, "-")); - const auto& input_array_a = model->GetArray(input_lhs); - const auto& input_array_b = model->GetArray(input_rhs); - const int dims_count = input_array_a.shape().dimensions_count(); - - // tf.slice(a, ...). - std::vector begin_indices_a = batch; - begin_indices_a.resize(dims_count); - std::vector slice_size_a = input_array_a.shape().dims(); - for (int i = 0; i < batch.size(); ++i) { - slice_size_a[i] = 1; - } - auto* slice_a_op = new SliceOperator; - slice_a_op->inputs = { - input_lhs, - CreateInt32Array(model, batch_name + "/slice_a/slice/begin", - begin_indices_a), - CreateInt32Array(model, batch_name + "/slice_a/slice/size", slice_size_a), - }; - slice_a_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_a")}; - auto& slice_a_op_output = model->GetOrCreateArray(slice_a_op->outputs[0]); - slice_a_op_output.data_type = input_array_a.data_type; - *tail_it = model->operators.emplace(*tail_it, slice_a_op) + 1; - - // Reshape to remove the first dimension ([1,M,N] -> [M,N]). - auto* slice_a_reshape_op = new TensorFlowReshapeOperator; - slice_a_reshape_op->inputs = { - slice_a_op->outputs[0], - CreateInt32Array(model, batch_name + "/slice_a/reshape/shape", - {-1, input_array_a.shape().dims(dims_count - 1)})}; - slice_a_reshape_op->outputs = { - AvailableArrayName(*model, batch_name + "/slice_a/reshape")}; - auto& slice_a_reshape_op_output = - model->GetOrCreateArray(slice_a_reshape_op->outputs[0]); - slice_a_reshape_op_output.data_type = input_array_a.data_type; - *tail_it = model->operators.emplace(*tail_it, slice_a_reshape_op) + 1; - - // tf.slice(b, ...). - std::vector begin_indices_b = batch; - begin_indices_b.resize(dims_count); - std::vector slice_size_b = input_array_b.shape().dims(); - for (int i = 0; i < batch.size(); ++i) { - slice_size_b[i] = 1; - } - auto* slice_b_op = new SliceOperator; - slice_b_op->inputs = { - input_rhs, - CreateInt32Array(model, batch_name + "/slice_b/slice/begin", - begin_indices_b), - CreateInt32Array(model, batch_name + "/slice_b/slice/size", slice_size_b), - }; - slice_b_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_b")}; - auto& slice_b_op_output = model->GetOrCreateArray(slice_b_op->outputs[0]); - slice_b_op_output.data_type = input_array_b.data_type; - *tail_it = model->operators.emplace(*tail_it, slice_b_op) + 1; - - // Reshape to remove the first dimension ([1,M,N] -> [M,N]). - auto* slice_b_reshape_op = new TensorFlowReshapeOperator; - slice_b_reshape_op->inputs = { - slice_b_op->outputs[0], - CreateInt32Array(model, batch_name + "/slice_b/reshape/shape", - {-1, input_array_b.shape().dims(dims_count - 1)})}; - slice_b_reshape_op->outputs = { - AvailableArrayName(*model, batch_name + "/slice_b/reshape")}; - auto& slice_b_reshape_op_output = - model->GetOrCreateArray(slice_b_reshape_op->outputs[0]); - slice_b_reshape_op_output.data_type = input_array_b.data_type; - *tail_it = model->operators.emplace(*tail_it, slice_b_reshape_op) + 1; - - // tf.matmul(slice_a, slice_b). - auto* matmul_op = new TensorFlowMatMulOperator; - matmul_op->inputs = {slice_a_reshape_op->outputs[0], - slice_b_reshape_op->outputs[0]}; - matmul_op->outputs = {AvailableArrayName(*model, batch_name)}; - auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]); - matmul_op_output.data_type = input_array_a.data_type; - *tail_it = model->operators.emplace(*tail_it, matmul_op) + 1; - - // Add to stack. - pack_inputs->push_back(matmul_op->outputs[0]); +absl::InlinedVector ToInlinedVector(const std::vector& vec) { + return absl::InlinedVector(vec.begin(), vec.end()); } -std::vector UnrollBatchMatMulRecursion( - const string& input_lhs, const string& input_rhs, - const BatchMatMulOperator* batch_op, Model* model, - std::vector>::iterator* tail_it, - const std::vector& batch_prefix) { - const auto& input_array_a = model->GetArray(input_lhs); - const auto& dims_vec = input_array_a.shape().dims(); - const int current_dim_size = dims_vec[batch_prefix.size()]; - std::vector batch_pack_inputs; +std::vector SliceInput( + const string& input, const string& base_name, const string& input_name, + const int batch_size, const Array& input_array, Model* model, + std::vector>::iterator* tail_it) { + int rank = input_array.shape().dimensions_count(); + int num_rows = input_array.shape().dims(rank - 2); + int num_cols = input_array.shape().dims(rank - 1); + // Reshape to rank-3 Tensor with first dimension as the batch size. + auto* reshape_op = new TensorFlowReshapeOperator; + reshape_op->inputs = { + input, + CreateInt32Array(model, absl::StrCat(base_name, "/reshape_a/shape"), + {batch_size, num_rows, num_cols})}; + reshape_op->outputs = {AvailableArrayName( + *model, absl::StrCat(base_name, "/reshape_", input_name, "/reshape"))}; + auto& reshape_op_output = model->GetOrCreateArray(reshape_op->outputs[0]); + reshape_op_output.data_type = input_array.data_type; + *tail_it = model->operators.emplace(*tail_it, reshape_op) + 1; - if (batch_prefix.size() + 3 == dims_vec.size()) { - // Base case - for (int batch = 0; batch < current_dim_size; ++batch) { - std::vector new_batch_prefix = batch_prefix; - new_batch_prefix.emplace_back(batch); - UnrollBatchMatMul3D(input_lhs, input_rhs, batch_op, new_batch_prefix, - model, tail_it, &batch_pack_inputs); - } - } else { - // Recursion - for (int batch = 0; batch < current_dim_size; ++batch) { - std::vector new_batch_prefix = batch_prefix; - new_batch_prefix.emplace_back(batch); - std::vector pack_inputs = UnrollBatchMatMulRecursion( - input_lhs, input_rhs, batch_op, model, tail_it, new_batch_prefix); + // Slice along each batch index and remember the slice output for future use. + std::vector slice_outputs; + for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + std::string batch_name = + absl::StrCat(base_name, "_b", batch_idx, "/slice_", input_name); + auto* slice_op = new SliceOperator; + slice_op->inputs = { + reshape_op->outputs[0], + CreateInt32Array(model, absl::StrCat(batch_name, "/slice/begin"), + {batch_idx, 0, 0}), + CreateInt32Array(model, absl::StrCat(batch_name, "/slice/size"), + {1, num_rows, num_cols})}; + slice_op->outputs = { + AvailableArrayName(*model, absl::StrCat(batch_name, "/slice"))}; + auto& slice_op_output = model->GetOrCreateArray(slice_op->outputs[0]); + slice_op_output.data_type = input_array.data_type; + *tail_it = model->operators.emplace(*tail_it, slice_op) + 1; - // The pack that will join all the individual matmul results together. - auto* pack_op = new PackOperator; - std::string batch_name = absl::StrCat( - batch_op->outputs[0], "_b", absl::StrJoin(new_batch_prefix, "-")); - pack_op->inputs = pack_inputs; - pack_op->outputs = {AvailableArrayName(*model, batch_name + "/pack")}; - auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]); - pack_op_output.data_type = input_array_a.data_type; - pack_op->axis = 0; - pack_op->values_count = pack_inputs.size(); - *tail_it = model->operators.emplace(*tail_it, pack_op) + 1; + // Reshape to rank-2: [1, num_rows, num_cols] -> [num_rows, num_cols]. + auto* slice_reshape_op = new TensorFlowReshapeOperator; + slice_reshape_op->inputs = { + slice_op->outputs[0], + CreateInt32Array(model, absl::StrCat(batch_name, "/reshape/shape"), + {num_rows, num_cols})}; + slice_reshape_op->outputs = { + AvailableArrayName(*model, absl::StrCat(batch_name, "/reshape"))}; + auto& slice_reshape_op_output = + model->GetOrCreateArray(slice_reshape_op->outputs[0]); + slice_reshape_op_output.data_type = input_array.data_type; + *tail_it = model->operators.emplace(*tail_it, slice_reshape_op) + 1; - batch_pack_inputs.push_back(pack_op->outputs[0]); - } + slice_outputs.push_back(slice_reshape_op->outputs[0]); } - return batch_pack_inputs; + return slice_outputs; } std::vector GetTransposePerm(const Array& input_array) { @@ -202,15 +130,6 @@ TransposeOperator* TransposeInput(const string& input, Model* model) { // Unrolls a BatchMatMul on the batch dimension. // We need to slice each batch out of the inputs, matmul them individually, then // stack them all back together at the end. -// -// This transform effectively looks like: -// result_slices = [] -// for bat in B: -// slice_a = tf.reshape(tf.slice(a, [bat, 0, 0], [1, M, N]), [M, N]) -// slice_b = tf.reshape(tf.slice(b, [bat, 0, 0], [1, M, N]), [M, N]) -// slice_c = tf.matmul(slice_a, slice_b) -// result_slices[bat] = slice_c -// result = tf.stack(result_slices) ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index, bool* modified) { *modified = false; @@ -220,7 +139,6 @@ TransposeOperator* TransposeInput(const string& input, Model* model) { } const auto* batch_op = static_cast(batch_op_it->get()); - auto& tail_it = batch_op_it; string input_lhs = batch_op->inputs[0]; @@ -246,20 +164,25 @@ TransposeOperator* TransposeInput(const string& input, Model* model) { } const auto& input_array_b = model->GetArray(input_rhs); - const int dims = input_array_a.shape().dimensions_count(); - for (int i = 0; i < dims - 2; ++i) { - CHECK_EQ(input_array_a.shape().dims(i), input_array_b.shape().dims(i)) - << "input array not consistent at index " << i; - } - CHECK_EQ(input_array_a.shape().dims(dims - 1), - input_array_b.shape().dims(dims - 2)) + // Ensure that input ranks are at least 2 and batch shapes are broadcastable. + const int dims_a = input_array_a.shape().dimensions_count(); + const int dims_b = input_array_b.shape().dimensions_count(); + CHECK_GE(dims_a, 2) << "First input must have rank >= 2"; + CHECK_GE(dims_b, 2) << "Second input must have rank >= 2"; + + ::tensorflow::MatMulBCast bcast( + ToInlinedVector(input_array_a.shape().dims()), + ToInlinedVector(input_array_b.shape().dims())); + CHECK(bcast.IsValid()) << "Input batch dimensions must be broadcastable"; + + CHECK_EQ(input_array_a.shape().dims(dims_a - 1), + input_array_b.shape().dims(dims_b - 2)) << "Input dimensions must be compatible for multipication. shape a = [" << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = [" << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]"; - if (dims == 2) { - // This is really just a MatMul. This likely means that someone hand-crafted - // a graphdef with a BatchMatMul when they really wanted a MatMul. + if (dims_a == 2 && dims_b == 2) { + // This is really just a MatMul. AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator", LogName(*batch_op)); auto* matmul_op = new TensorFlowMatMulOperator; @@ -271,23 +194,65 @@ TransposeOperator* TransposeInput(const string& input, Model* model) { *modified = true; return ::tensorflow::Status::OK(); } - - CHECK_GE(input_array_a.shape().dimensions_count(), 3) - << "Input arrays must have rank >= 3"; - - const auto& dims_vec = input_array_a.shape().dims(); AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op), - std::accumulate(dims_vec.begin(), dims_vec.end() - 2, 1, - std::multiplies())); + bcast.output_batch_size()); + string base_name = std::string(batch_op->outputs[0]); - std::vector pack_inputs = UnrollBatchMatMulRecursion( - input_lhs, input_rhs, batch_op, model, &tail_it, {}); + // Compute slices for each batch in the LHS and RHS. + std::vector slice_a_outputs = + SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a, + model, &tail_it); + std::vector slice_b_outputs = + SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b, + model, &tail_it); + + // Compute (single batch) MatMul for each output batch. The MatMul outputs are + // then packed together into one output Tensor. + std::vector pack_inputs; + for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) { + std::string batch_name = + absl::StrCat(batch_op->outputs[0], "_b", batch_idx); + const int a_batch_idx = bcast.IsBroadcastingRequired() + ? bcast.x_batch_indices()[batch_idx] + : batch_idx; + const int b_batch_idx = bcast.IsBroadcastingRequired() + ? bcast.y_batch_indices()[batch_idx] + : batch_idx; + auto* matmul_op = new TensorFlowMatMulOperator; + matmul_op->inputs = {slice_a_outputs[a_batch_idx], + slice_b_outputs[b_batch_idx]}; + matmul_op->outputs = {AvailableArrayName(*model, batch_name)}; + auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]); + matmul_op_output.data_type = input_array_a.data_type; + tail_it = model->operators.emplace(tail_it, matmul_op) + 1; + + // Add to stack. + pack_inputs.push_back(matmul_op->outputs[0]); + } + + // Combine the result of each individual MatMul into a rank-3 Tensor. auto* pack_op = new PackOperator; pack_op->inputs = pack_inputs; - pack_op->outputs = {batch_op->outputs[0]}; + pack_op->outputs = {AvailableArrayName(*model, base_name + "/pack")}; + auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]); + pack_op_output.data_type = input_array_a.data_type; pack_op->axis = 0; pack_op->values_count = pack_inputs.size(); - model->operators.emplace(tail_it, pack_op); + tail_it = model->operators.emplace(tail_it, pack_op) + 1; + + // Reshape the rank-3 Tensor into the correct output shape. + const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes(); + std::vector result_shape(result_batch_shape.begin(), + result_batch_shape.end()); + result_shape.push_back(input_array_a.shape().dims(dims_a - 2)); + result_shape.push_back(input_array_b.shape().dims(dims_b - 1)); + + auto* reshape_result_op = new TensorFlowReshapeOperator; + reshape_result_op->inputs = { + pack_op->outputs[0], + CreateInt32Array(model, base_name + "/reshape_out/shape", result_shape)}; + reshape_result_op->outputs = {batch_op->outputs[0]}; + model->operators.emplace(tail_it, reshape_result_op); // Remove the old batch matmul now that we've unrolled. batch_op_it = model->operators.begin(); diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index e7318e7082b..2b0c3e982cc 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -971,9 +971,8 @@ struct TensorFlowIdentityOperator : Operator { TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {} }; -// Batch matrix multiplication operator. This comes from the (deprecated) -// tf.batch_matmul or a tf.matmul that has rank 3. dims(0) is the batch count -// and it can be trivially unrolled into a series of matmuls on each element. +// Batch matrix multiplication operator. This comes from a tf.matmul where one +// of the operands has rank 3 or more. // // Inputs: // inputs[0]: required: the left-hand side matrix