Add support for full BatchMatMulV2 functionality (broadcasting) in TFLite.

PiperOrigin-RevId: 245290978
Commit: 17db87c632 (parent: 87b5f63ba7)
Author: Anudhyan Boral, 2019-04-25 12:59:46 -07:00
Committed by: TensorFlower Gardener
4 changed files with 125 additions and 169 deletions
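For context, a minimal sketch of the broadcasting semantics this change enables
(assuming TensorFlow >= 1.14, where tf.matmul lowers to BatchMatMulV2; the
shapes mirror two of the unroll_batch_matmul test cases removed from
KNOWN_BUGS in the first file below):

import numpy as np
import tensorflow as tf

# BatchMatMulV2 broadcasts the leading (batch) dimensions numpy-style.
# "Broadcast both operands": batch dims (3, 1) and (1, 4) broadcast to
# (3, 4), and each individual product is a [5, 3] x [3, 7] matmul.
a = np.ones((3, 1, 5, 3), dtype=np.float32)
b = np.ones((1, 4, 3, 7), dtype=np.float32)
print(tf.matmul(a, b).shape)  # (3, 4, 5, 7)

# "Single batch with non-empty batch broadcast": (1, 5, 3) x (4, 3, 7).
print(tf.matmul(np.ones((1, 5, 3), np.float32),
                np.ones((4, 3, 7), np.float32)).shape)  # (4, 5, 7)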


@@ -76,16 +76,6 @@ KNOWN_BUGS = {
     r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
     # Div will use floordiv.
     r"div.*int32": "72051395",
-    # TFLite/Toco does not support BatchMatMul(V2) broadcasting semantic yet.
-    # Simple broadcast.
-    r"unroll_batch_matmul.*shape=\[\(1,2,3\),\(3,5\).*": "130887526",
-    # Empty batch broadcast.
-    r"unroll_batch_matmul.*shape=\[\(2,5,3\),\(3,7\).*": "130887526",
-    # Single batch with non-empty batch broadcast.
-    r"unroll_batch_matmul.*shape=\[\(1,5,3\),\(4,3,7\).*": "130887526",
-    # Broadcast both operands
-    r"unroll_batch_matmul.*shape=\[\(3,1,5,3\),\(1,4,3,7\).*": "130887526",
 }


@@ -283,9 +283,11 @@ cc_library(
         ":runtime",
         ":toco_port",
         ":tooling_util",
+        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/lite/kernels/internal:quantization_util",
        "//tensorflow/lite/kernels/internal:strided_slice_logic",
+        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ],


@@ -18,147 +18,75 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>
 
+#include "absl/container/inlined_vector.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/matmul_bcast.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
 
 namespace toco {
 
 namespace {
 
-void UnrollBatchMatMul3D(
-    const string& input_lhs, const string& input_rhs,
-    const BatchMatMulOperator* batch_op, const std::vector<int> batch,
-    Model* model, std::vector<std::unique_ptr<Operator>>::iterator* tail_it,
-    std::vector<string>* pack_inputs) {
-  const std::string batch_name =
-      absl::StrCat(batch_op->outputs[0], "_b", absl::StrJoin(batch, "-"));
-  const auto& input_array_a = model->GetArray(input_lhs);
-  const auto& input_array_b = model->GetArray(input_rhs);
-  const int dims_count = input_array_a.shape().dimensions_count();
-
-  // tf.slice(a, ...).
-  std::vector<int> begin_indices_a = batch;
-  begin_indices_a.resize(dims_count);
-  std::vector<int> slice_size_a = input_array_a.shape().dims();
-  for (int i = 0; i < batch.size(); ++i) {
-    slice_size_a[i] = 1;
-  }
-  auto* slice_a_op = new SliceOperator;
-  slice_a_op->inputs = {
-      input_lhs,
-      CreateInt32Array(model, batch_name + "/slice_a/slice/begin",
-                       begin_indices_a),
-      CreateInt32Array(model, batch_name + "/slice_a/slice/size", slice_size_a),
-  };
-  slice_a_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_a")};
-  auto& slice_a_op_output = model->GetOrCreateArray(slice_a_op->outputs[0]);
-  slice_a_op_output.data_type = input_array_a.data_type;
-  *tail_it = model->operators.emplace(*tail_it, slice_a_op) + 1;
-
-  // Reshape to remove the first dimension ([1,M,N] -> [M,N]).
-  auto* slice_a_reshape_op = new TensorFlowReshapeOperator;
-  slice_a_reshape_op->inputs = {
-      slice_a_op->outputs[0],
-      CreateInt32Array(model, batch_name + "/slice_a/reshape/shape",
-                       {-1, input_array_a.shape().dims(dims_count - 1)})};
-  slice_a_reshape_op->outputs = {
-      AvailableArrayName(*model, batch_name + "/slice_a/reshape")};
-  auto& slice_a_reshape_op_output =
-      model->GetOrCreateArray(slice_a_reshape_op->outputs[0]);
-  slice_a_reshape_op_output.data_type = input_array_a.data_type;
-  *tail_it = model->operators.emplace(*tail_it, slice_a_reshape_op) + 1;
-
-  // tf.slice(b, ...).
-  std::vector<int> begin_indices_b = batch;
-  begin_indices_b.resize(dims_count);
-  std::vector<int> slice_size_b = input_array_b.shape().dims();
-  for (int i = 0; i < batch.size(); ++i) {
-    slice_size_b[i] = 1;
-  }
-  auto* slice_b_op = new SliceOperator;
-  slice_b_op->inputs = {
-      input_rhs,
-      CreateInt32Array(model, batch_name + "/slice_b/slice/begin",
-                       begin_indices_b),
-      CreateInt32Array(model, batch_name + "/slice_b/slice/size", slice_size_b),
-  };
-  slice_b_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_b")};
-  auto& slice_b_op_output = model->GetOrCreateArray(slice_b_op->outputs[0]);
-  slice_b_op_output.data_type = input_array_b.data_type;
-  *tail_it = model->operators.emplace(*tail_it, slice_b_op) + 1;
-
-  // Reshape to remove the first dimension ([1,M,N] -> [M,N]).
-  auto* slice_b_reshape_op = new TensorFlowReshapeOperator;
-  slice_b_reshape_op->inputs = {
-      slice_b_op->outputs[0],
-      CreateInt32Array(model, batch_name + "/slice_b/reshape/shape",
-                       {-1, input_array_b.shape().dims(dims_count - 1)})};
-  slice_b_reshape_op->outputs = {
-      AvailableArrayName(*model, batch_name + "/slice_b/reshape")};
-  auto& slice_b_reshape_op_output =
-      model->GetOrCreateArray(slice_b_reshape_op->outputs[0]);
-  slice_b_reshape_op_output.data_type = input_array_b.data_type;
-  *tail_it = model->operators.emplace(*tail_it, slice_b_reshape_op) + 1;
-
-  // tf.matmul(slice_a, slice_b).
-  auto* matmul_op = new TensorFlowMatMulOperator;
-  matmul_op->inputs = {slice_a_reshape_op->outputs[0],
-                       slice_b_reshape_op->outputs[0]};
-  matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
-  auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
-  matmul_op_output.data_type = input_array_a.data_type;
-  *tail_it = model->operators.emplace(*tail_it, matmul_op) + 1;
-
-  // Add to stack.
-  pack_inputs->push_back(matmul_op->outputs[0]);
-}
-
-std::vector<string> UnrollBatchMatMulRecursion(
-    const string& input_lhs, const string& input_rhs,
-    const BatchMatMulOperator* batch_op, Model* model,
-    std::vector<std::unique_ptr<Operator>>::iterator* tail_it,
-    const std::vector<int>& batch_prefix) {
-  const auto& input_array_a = model->GetArray(input_lhs);
-  const auto& dims_vec = input_array_a.shape().dims();
-  const int current_dim_size = dims_vec[batch_prefix.size()];
-  std::vector<string> batch_pack_inputs;
-
-  if (batch_prefix.size() + 3 == dims_vec.size()) {
-    // Base case
-    for (int batch = 0; batch < current_dim_size; ++batch) {
-      std::vector<int> new_batch_prefix = batch_prefix;
-      new_batch_prefix.emplace_back(batch);
-      UnrollBatchMatMul3D(input_lhs, input_rhs, batch_op, new_batch_prefix,
-                          model, tail_it, &batch_pack_inputs);
-    }
-  } else {
-    // Recursion
-    for (int batch = 0; batch < current_dim_size; ++batch) {
-      std::vector<int> new_batch_prefix = batch_prefix;
-      new_batch_prefix.emplace_back(batch);
-      std::vector<string> pack_inputs = UnrollBatchMatMulRecursion(
-          input_lhs, input_rhs, batch_op, model, tail_it, new_batch_prefix);
-
-      // The pack that will join all the individual matmul results together.
-      auto* pack_op = new PackOperator;
-      std::string batch_name = absl::StrCat(
-          batch_op->outputs[0], "_b", absl::StrJoin(new_batch_prefix, "-"));
-      pack_op->inputs = pack_inputs;
-      pack_op->outputs = {AvailableArrayName(*model, batch_name + "/pack")};
-      auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
-      pack_op_output.data_type = input_array_a.data_type;
-      pack_op->axis = 0;
-      pack_op->values_count = pack_inputs.size();
-      *tail_it = model->operators.emplace(*tail_it, pack_op) + 1;
-
-      batch_pack_inputs.push_back(pack_op->outputs[0]);
-    }
-  }
-  return batch_pack_inputs;
-}
+absl::InlinedVector<int64, 4> ToInlinedVector(const std::vector<int>& vec) {
+  return absl::InlinedVector<int64, 4>(vec.begin(), vec.end());
+}
+
+std::vector<string> SliceInput(
+    const string& input, const string& base_name, const string& input_name,
+    const int batch_size, const Array& input_array, Model* model,
+    std::vector<std::unique_ptr<Operator>>::iterator* tail_it) {
+  int rank = input_array.shape().dimensions_count();
+  int num_rows = input_array.shape().dims(rank - 2);
+  int num_cols = input_array.shape().dims(rank - 1);
+  // Reshape to rank-3 Tensor with first dimension as the batch size.
+  auto* reshape_op = new TensorFlowReshapeOperator;
+  reshape_op->inputs = {
+      input,
+      CreateInt32Array(model, absl::StrCat(base_name, "/reshape_a/shape"),
+                       {batch_size, num_rows, num_cols})};
+  reshape_op->outputs = {AvailableArrayName(
+      *model, absl::StrCat(base_name, "/reshape_", input_name, "/reshape"))};
+  auto& reshape_op_output = model->GetOrCreateArray(reshape_op->outputs[0]);
+  reshape_op_output.data_type = input_array.data_type;
+  *tail_it = model->operators.emplace(*tail_it, reshape_op) + 1;
+
+  // Slice along each batch index and remember the slice output for future use.
+  std::vector<string> slice_outputs;
+  for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+    std::string batch_name =
+        absl::StrCat(base_name, "_b", batch_idx, "/slice_", input_name);
+    auto* slice_op = new SliceOperator;
+    slice_op->inputs = {
+        reshape_op->outputs[0],
+        CreateInt32Array(model, absl::StrCat(batch_name, "/slice/begin"),
+                         {batch_idx, 0, 0}),
+        CreateInt32Array(model, absl::StrCat(batch_name, "/slice/size"),
+                         {1, num_rows, num_cols})};
+    slice_op->outputs = {
+        AvailableArrayName(*model, absl::StrCat(batch_name, "/slice"))};
+    auto& slice_op_output = model->GetOrCreateArray(slice_op->outputs[0]);
+    slice_op_output.data_type = input_array.data_type;
+    *tail_it = model->operators.emplace(*tail_it, slice_op) + 1;
+
+    // Reshape to rank-2: [1, num_rows, num_cols] -> [num_rows, num_cols].
+    auto* slice_reshape_op = new TensorFlowReshapeOperator;
+    slice_reshape_op->inputs = {
+        slice_op->outputs[0],
+        CreateInt32Array(model, absl::StrCat(batch_name, "/reshape/shape"),
+                         {num_rows, num_cols})};
+    slice_reshape_op->outputs = {
+        AvailableArrayName(*model, absl::StrCat(batch_name, "/reshape"))};
+    auto& slice_reshape_op_output =
+        model->GetOrCreateArray(slice_reshape_op->outputs[0]);
+    slice_reshape_op_output.data_type = input_array.data_type;
+    *tail_it = model->operators.emplace(*tail_it, slice_reshape_op) + 1;
+
+    slice_outputs.push_back(slice_reshape_op->outputs[0]);
+  }
+  return slice_outputs;
+}
 
 std::vector<int32> GetTransposePerm(const Array& input_array) {
@@ -202,15 +130,6 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
 // Unrolls a BatchMatMul on the batch dimension.
 // We need to slice each batch out of the inputs, matmul them individually, then
 // stack them all back together at the end.
-//
-// This transform effectively looks like:
-//  result_slices = []
-//  for bat in B:
-//    slice_a = tf.reshape(tf.slice(a, [bat, 0, 0], [1, M, N]), [M, N])
-//    slice_b = tf.reshape(tf.slice(b, [bat, 0, 0], [1, M, N]), [M, N])
-//    slice_c = tf.matmul(slice_a, slice_b)
-//    result_slices[bat] = slice_c
-//  result = tf.stack(result_slices)
 ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
                                             bool* modified) {
   *modified = false;
@@ -220,7 +139,6 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
   }
   const auto* batch_op =
      static_cast<const BatchMatMulOperator*>(batch_op_it->get());
-
   auto& tail_it = batch_op_it;
 
   string input_lhs = batch_op->inputs[0];
@@ -246,20 +164,25 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
   }
   const auto& input_array_b = model->GetArray(input_rhs);
 
-  const int dims = input_array_a.shape().dimensions_count();
-  for (int i = 0; i < dims - 2; ++i) {
-    CHECK_EQ(input_array_a.shape().dims(i), input_array_b.shape().dims(i))
-        << "input array not consistent at index " << i;
-  }
-  CHECK_EQ(input_array_a.shape().dims(dims - 1),
-           input_array_b.shape().dims(dims - 2))
+  // Ensure that input ranks are at least 2 and batch shapes are broadcastable.
+  const int dims_a = input_array_a.shape().dimensions_count();
+  const int dims_b = input_array_b.shape().dimensions_count();
+  CHECK_GE(dims_a, 2) << "First input must have rank >= 2";
+  CHECK_GE(dims_b, 2) << "Second input must have rank >= 2";
+
+  ::tensorflow::MatMulBCast bcast(
+      ToInlinedVector(input_array_a.shape().dims()),
+      ToInlinedVector(input_array_b.shape().dims()));
+  CHECK(bcast.IsValid()) << "Input batch dimensions must be broadcastable";
+
+  CHECK_EQ(input_array_a.shape().dims(dims_a - 1),
+           input_array_b.shape().dims(dims_b - 2))
       << "Input dimensions must be compatible for multiplication. shape a = ["
      << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = ["
      << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]";
 
-  if (dims == 2) {
-    // This is really just a MatMul. This likely means that someone hand-crafted
-    // a graphdef with a BatchMatMul when they really wanted a MatMul.
+  if (dims_a == 2 && dims_b == 2) {
+    // This is really just a MatMul.
     AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator",
                 LogName(*batch_op));
    auto* matmul_op = new TensorFlowMatMulOperator;
@@ -271,23 +194,65 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
     *modified = true;
     return ::tensorflow::Status::OK();
   }
 
-  CHECK_GE(input_array_a.shape().dimensions_count(), 3)
-      << "Input arrays must have rank >= 3";
-  const auto& dims_vec = input_array_a.shape().dims();
   AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
-              std::accumulate(dims_vec.begin(), dims_vec.end() - 2, 1,
-                              std::multiplies<int>()));
-  std::vector<string> pack_inputs = UnrollBatchMatMulRecursion(
-      input_lhs, input_rhs, batch_op, model, &tail_it, {});
+              bcast.output_batch_size());
+  string base_name = std::string(batch_op->outputs[0]);
+
+  // Compute slices for each batch in the LHS and RHS.
+  std::vector<string> slice_a_outputs =
+      SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a,
+                 model, &tail_it);
+  std::vector<string> slice_b_outputs =
+      SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b,
+                 model, &tail_it);
+
+  // Compute (single batch) MatMul for each output batch. The MatMul outputs are
+  // then packed together into one output Tensor.
+  std::vector<string> pack_inputs;
+  for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
+    std::string batch_name =
+        absl::StrCat(batch_op->outputs[0], "_b", batch_idx);
+    const int a_batch_idx = bcast.IsBroadcastingRequired()
+                                ? bcast.x_batch_indices()[batch_idx]
+                                : batch_idx;
+    const int b_batch_idx = bcast.IsBroadcastingRequired()
+                                ? bcast.y_batch_indices()[batch_idx]
+                                : batch_idx;
+    auto* matmul_op = new TensorFlowMatMulOperator;
+    matmul_op->inputs = {slice_a_outputs[a_batch_idx],
+                         slice_b_outputs[b_batch_idx]};
+    matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
+    auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
+    matmul_op_output.data_type = input_array_a.data_type;
+    tail_it = model->operators.emplace(tail_it, matmul_op) + 1;
+
+    // Add to stack.
+    pack_inputs.push_back(matmul_op->outputs[0]);
+  }
 
+  // Combine the result of each individual MatMul into a rank-3 Tensor.
   auto* pack_op = new PackOperator;
   pack_op->inputs = pack_inputs;
-  pack_op->outputs = {batch_op->outputs[0]};
+  pack_op->outputs = {AvailableArrayName(*model, base_name + "/pack")};
+  auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
+  pack_op_output.data_type = input_array_a.data_type;
   pack_op->axis = 0;
   pack_op->values_count = pack_inputs.size();
-  model->operators.emplace(tail_it, pack_op);
+  tail_it = model->operators.emplace(tail_it, pack_op) + 1;
+
+  // Reshape the rank-3 Tensor into the correct output shape.
+  const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
+  std::vector<int> result_shape(result_batch_shape.begin(),
+                                result_batch_shape.end());
+  result_shape.push_back(input_array_a.shape().dims(dims_a - 2));
+  result_shape.push_back(input_array_b.shape().dims(dims_b - 1));
+  auto* reshape_result_op = new TensorFlowReshapeOperator;
+  reshape_result_op->inputs = {
+      pack_op->outputs[0],
+      CreateInt32Array(model, base_name + "/reshape_out/shape", result_shape)};
+  reshape_result_op->outputs = {batch_op->outputs[0]};
+  model->operators.emplace(tail_it, reshape_result_op);
 
   // Remove the old batch matmul now that we've unrolled.
   batch_op_it = model->operators.begin();
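Taken together, the rewritten pass lowers a broadcasting BatchMatMul to
reshape -> per-batch slice -> plain MatMul -> pack -> reshape. Below is a
rough NumPy sketch of the same computation (a hypothetical helper, not part
of this commit; np.broadcast_shapes needs NumPy >= 1.20, and the index
mapping stands in for MatMulBCast's x_batch_indices()/y_batch_indices()):

import numpy as np

def unroll_batch_matmul_v2(a, b):
    """NumPy sketch of the unrolled, broadcasting BatchMatMul lowering."""
    rows_a, cols_a = a.shape[-2:]
    rows_b, cols_b = b.shape[-2:]
    assert cols_a == rows_b, "inner dimensions must be compatible"
    batch_a, batch_b = a.shape[:-2], b.shape[:-2]
    # MatMulBCast: broadcast (and validate) the batch shapes.
    out_batch = np.broadcast_shapes(batch_a, batch_b)

    # SliceInput: flatten each operand to rank 3, [batch_size, rows, cols],
    # so each batch can be sliced out with a fixed [1, rows, cols] window.
    a3 = a.reshape(-1, rows_a, cols_a)
    b3 = b.reshape(-1, rows_b, cols_b)

    # Map each output batch index to the source batch it reads on each side.
    def src_indices(batch):
        size = int(np.prod(batch, dtype=np.int64))
        return np.broadcast_to(np.arange(size).reshape(batch),
                               out_batch).reshape(-1)

    idx_a, idx_b = src_indices(batch_a), src_indices(batch_b)

    # One plain MatMul per output batch, then PackOperator along axis 0.
    packed = np.stack([a3[i] @ b3[j] for i, j in zip(idx_a, idx_b)])

    # Final reshape to the broadcast batch shape plus [rows_a, cols_b].
    return packed.reshape(*out_batch, rows_a, cols_b)

print(unroll_batch_matmul_v2(np.ones((3, 1, 5, 3)),
                             np.ones((1, 4, 3, 7))).shape)  # (3, 4, 5, 7)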


@@ -971,9 +971,8 @@ struct TensorFlowIdentityOperator : Operator {
   TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {}
 };
 
-// Batch matrix multiplication operator. This comes from the (deprecated)
-// tf.batch_matmul or a tf.matmul that has rank 3. dims(0) is the batch count
-// and it can be trivially unrolled into a series of matmuls on each element.
+// Batch matrix multiplication operator. This comes from a tf.matmul where one
+// of the operands has rank 3 or more.
 //
 // Inputs:
 //   inputs[0]: required: the left-hand side matrix