Add support for full BatchMatMulV2 functionality (broadcasting) in TFLite.
PiperOrigin-RevId: 245290978
This commit is contained in:
parent
87b5f63ba7
commit
17db87c632
@ -76,16 +76,6 @@ KNOWN_BUGS = {
|
|||||||
r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
|
r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
|
||||||
# Div will use floordiv.
|
# Div will use floordiv.
|
||||||
r"div.*int32": "72051395",
|
r"div.*int32": "72051395",
|
||||||
|
|
||||||
# TFLite/Toco does not support BatchMatMul(V2) broadcasting semantic yet.
|
|
||||||
# Simple broadcast.
|
|
||||||
r"unroll_batch_matmul.*shape=\[\(1,2,3\),\(3,5\).*": "130887526",
|
|
||||||
# Empty batch broadcast.
|
|
||||||
r"unroll_batch_matmul.*shape=\[\(2,5,3\),\(3,7\).*": "130887526",
|
|
||||||
# Single batch with non-empty batch broadcast.
|
|
||||||
r"unroll_batch_matmul.*shape=\[\(1,5,3\),\(4,3,7\).*": "130887526",
|
|
||||||
# Broadcast both operands
|
|
||||||
r"unroll_batch_matmul.*shape=\[\(3,1,5,3\),\(1,4,3,7\).*": "130887526",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -283,9 +283,11 @@ cc_library(
|
|||||||
":runtime",
|
":runtime",
|
||||||
":toco_port",
|
":toco_port",
|
||||||
":tooling_util",
|
":tooling_util",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/lite/kernels/internal:quantization_util",
|
"//tensorflow/lite/kernels/internal:quantization_util",
|
||||||
"//tensorflow/lite/kernels/internal:strided_slice_logic",
|
"//tensorflow/lite/kernels/internal:strided_slice_logic",
|
||||||
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
|
@ -18,147 +18,75 @@ limitations under the License.
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/util/matmul_bcast.h"
|
||||||
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
|
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
|
||||||
#include "tensorflow/lite/toco/model.h"
|
#include "tensorflow/lite/toco/model.h"
|
||||||
#include "tensorflow/lite/toco/tooling_util.h"
|
#include "tensorflow/lite/toco/tooling_util.h"
|
||||||
|
|
||||||
namespace toco {
|
namespace toco {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void UnrollBatchMatMul3D(
|
absl::InlinedVector<int64, 4> ToInlinedVector(const std::vector<int>& vec) {
|
||||||
const string& input_lhs, const string& input_rhs,
|
return absl::InlinedVector<int64, 4>(vec.begin(), vec.end());
|
||||||
const BatchMatMulOperator* batch_op, const std::vector<int> batch,
|
|
||||||
Model* model, std::vector<std::unique_ptr<Operator>>::iterator* tail_it,
|
|
||||||
std::vector<string>* pack_inputs) {
|
|
||||||
const std::string batch_name =
|
|
||||||
absl::StrCat(batch_op->outputs[0], "_b", absl::StrJoin(batch, "-"));
|
|
||||||
const auto& input_array_a = model->GetArray(input_lhs);
|
|
||||||
const auto& input_array_b = model->GetArray(input_rhs);
|
|
||||||
const int dims_count = input_array_a.shape().dimensions_count();
|
|
||||||
|
|
||||||
// tf.slice(a, ...).
|
|
||||||
std::vector<int> begin_indices_a = batch;
|
|
||||||
begin_indices_a.resize(dims_count);
|
|
||||||
std::vector<int> slice_size_a = input_array_a.shape().dims();
|
|
||||||
for (int i = 0; i < batch.size(); ++i) {
|
|
||||||
slice_size_a[i] = 1;
|
|
||||||
}
|
|
||||||
auto* slice_a_op = new SliceOperator;
|
|
||||||
slice_a_op->inputs = {
|
|
||||||
input_lhs,
|
|
||||||
CreateInt32Array(model, batch_name + "/slice_a/slice/begin",
|
|
||||||
begin_indices_a),
|
|
||||||
CreateInt32Array(model, batch_name + "/slice_a/slice/size", slice_size_a),
|
|
||||||
};
|
|
||||||
slice_a_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_a")};
|
|
||||||
auto& slice_a_op_output = model->GetOrCreateArray(slice_a_op->outputs[0]);
|
|
||||||
slice_a_op_output.data_type = input_array_a.data_type;
|
|
||||||
*tail_it = model->operators.emplace(*tail_it, slice_a_op) + 1;
|
|
||||||
|
|
||||||
// Reshape to remove the first dimension ([1,M,N] -> [M,N]).
|
|
||||||
auto* slice_a_reshape_op = new TensorFlowReshapeOperator;
|
|
||||||
slice_a_reshape_op->inputs = {
|
|
||||||
slice_a_op->outputs[0],
|
|
||||||
CreateInt32Array(model, batch_name + "/slice_a/reshape/shape",
|
|
||||||
{-1, input_array_a.shape().dims(dims_count - 1)})};
|
|
||||||
slice_a_reshape_op->outputs = {
|
|
||||||
AvailableArrayName(*model, batch_name + "/slice_a/reshape")};
|
|
||||||
auto& slice_a_reshape_op_output =
|
|
||||||
model->GetOrCreateArray(slice_a_reshape_op->outputs[0]);
|
|
||||||
slice_a_reshape_op_output.data_type = input_array_a.data_type;
|
|
||||||
*tail_it = model->operators.emplace(*tail_it, slice_a_reshape_op) + 1;
|
|
||||||
|
|
||||||
// tf.slice(b, ...).
|
|
||||||
std::vector<int> begin_indices_b = batch;
|
|
||||||
begin_indices_b.resize(dims_count);
|
|
||||||
std::vector<int> slice_size_b = input_array_b.shape().dims();
|
|
||||||
for (int i = 0; i < batch.size(); ++i) {
|
|
||||||
slice_size_b[i] = 1;
|
|
||||||
}
|
|
||||||
auto* slice_b_op = new SliceOperator;
|
|
||||||
slice_b_op->inputs = {
|
|
||||||
input_rhs,
|
|
||||||
CreateInt32Array(model, batch_name + "/slice_b/slice/begin",
|
|
||||||
begin_indices_b),
|
|
||||||
CreateInt32Array(model, batch_name + "/slice_b/slice/size", slice_size_b),
|
|
||||||
};
|
|
||||||
slice_b_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_b")};
|
|
||||||
auto& slice_b_op_output = model->GetOrCreateArray(slice_b_op->outputs[0]);
|
|
||||||
slice_b_op_output.data_type = input_array_b.data_type;
|
|
||||||
*tail_it = model->operators.emplace(*tail_it, slice_b_op) + 1;
|
|
||||||
|
|
||||||
// Reshape to remove the first dimension ([1,M,N] -> [M,N]).
|
|
||||||
auto* slice_b_reshape_op = new TensorFlowReshapeOperator;
|
|
||||||
slice_b_reshape_op->inputs = {
|
|
||||||
slice_b_op->outputs[0],
|
|
||||||
CreateInt32Array(model, batch_name + "/slice_b/reshape/shape",
|
|
||||||
{-1, input_array_b.shape().dims(dims_count - 1)})};
|
|
||||||
slice_b_reshape_op->outputs = {
|
|
||||||
AvailableArrayName(*model, batch_name + "/slice_b/reshape")};
|
|
||||||
auto& slice_b_reshape_op_output =
|
|
||||||
model->GetOrCreateArray(slice_b_reshape_op->outputs[0]);
|
|
||||||
slice_b_reshape_op_output.data_type = input_array_b.data_type;
|
|
||||||
*tail_it = model->operators.emplace(*tail_it, slice_b_reshape_op) + 1;
|
|
||||||
|
|
||||||
// tf.matmul(slice_a, slice_b).
|
|
||||||
auto* matmul_op = new TensorFlowMatMulOperator;
|
|
||||||
matmul_op->inputs = {slice_a_reshape_op->outputs[0],
|
|
||||||
slice_b_reshape_op->outputs[0]};
|
|
||||||
matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
|
|
||||||
auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
|
|
||||||
matmul_op_output.data_type = input_array_a.data_type;
|
|
||||||
*tail_it = model->operators.emplace(*tail_it, matmul_op) + 1;
|
|
||||||
|
|
||||||
// Add to stack.
|
|
||||||
pack_inputs->push_back(matmul_op->outputs[0]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> UnrollBatchMatMulRecursion(
|
std::vector<string> SliceInput(
|
||||||
const string& input_lhs, const string& input_rhs,
|
const string& input, const string& base_name, const string& input_name,
|
||||||
const BatchMatMulOperator* batch_op, Model* model,
|
const int batch_size, const Array& input_array, Model* model,
|
||||||
std::vector<std::unique_ptr<Operator>>::iterator* tail_it,
|
std::vector<std::unique_ptr<Operator>>::iterator* tail_it) {
|
||||||
const std::vector<int>& batch_prefix) {
|
int rank = input_array.shape().dimensions_count();
|
||||||
const auto& input_array_a = model->GetArray(input_lhs);
|
int num_rows = input_array.shape().dims(rank - 2);
|
||||||
const auto& dims_vec = input_array_a.shape().dims();
|
int num_cols = input_array.shape().dims(rank - 1);
|
||||||
const int current_dim_size = dims_vec[batch_prefix.size()];
|
// Reshape to rank-3 Tensor with first dimension as the batch size.
|
||||||
std::vector<string> batch_pack_inputs;
|
auto* reshape_op = new TensorFlowReshapeOperator;
|
||||||
|
reshape_op->inputs = {
|
||||||
|
input,
|
||||||
|
CreateInt32Array(model, absl::StrCat(base_name, "/reshape_a/shape"),
|
||||||
|
{batch_size, num_rows, num_cols})};
|
||||||
|
reshape_op->outputs = {AvailableArrayName(
|
||||||
|
*model, absl::StrCat(base_name, "/reshape_", input_name, "/reshape"))};
|
||||||
|
auto& reshape_op_output = model->GetOrCreateArray(reshape_op->outputs[0]);
|
||||||
|
reshape_op_output.data_type = input_array.data_type;
|
||||||
|
*tail_it = model->operators.emplace(*tail_it, reshape_op) + 1;
|
||||||
|
|
||||||
if (batch_prefix.size() + 3 == dims_vec.size()) {
|
// Slice along each batch index and remember the slice output for future use.
|
||||||
// Base case
|
std::vector<string> slice_outputs;
|
||||||
for (int batch = 0; batch < current_dim_size; ++batch) {
|
for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||||
std::vector<int> new_batch_prefix = batch_prefix;
|
std::string batch_name =
|
||||||
new_batch_prefix.emplace_back(batch);
|
absl::StrCat(base_name, "_b", batch_idx, "/slice_", input_name);
|
||||||
UnrollBatchMatMul3D(input_lhs, input_rhs, batch_op, new_batch_prefix,
|
auto* slice_op = new SliceOperator;
|
||||||
model, tail_it, &batch_pack_inputs);
|
slice_op->inputs = {
|
||||||
}
|
reshape_op->outputs[0],
|
||||||
} else {
|
CreateInt32Array(model, absl::StrCat(batch_name, "/slice/begin"),
|
||||||
// Recursion
|
{batch_idx, 0, 0}),
|
||||||
for (int batch = 0; batch < current_dim_size; ++batch) {
|
CreateInt32Array(model, absl::StrCat(batch_name, "/slice/size"),
|
||||||
std::vector<int> new_batch_prefix = batch_prefix;
|
{1, num_rows, num_cols})};
|
||||||
new_batch_prefix.emplace_back(batch);
|
slice_op->outputs = {
|
||||||
std::vector<string> pack_inputs = UnrollBatchMatMulRecursion(
|
AvailableArrayName(*model, absl::StrCat(batch_name, "/slice"))};
|
||||||
input_lhs, input_rhs, batch_op, model, tail_it, new_batch_prefix);
|
auto& slice_op_output = model->GetOrCreateArray(slice_op->outputs[0]);
|
||||||
|
slice_op_output.data_type = input_array.data_type;
|
||||||
|
*tail_it = model->operators.emplace(*tail_it, slice_op) + 1;
|
||||||
|
|
||||||
// The pack that will join all the individual matmul results together.
|
// Reshape to rank-2: [1, num_rows, num_cols] -> [num_rows, num_cols].
|
||||||
auto* pack_op = new PackOperator;
|
auto* slice_reshape_op = new TensorFlowReshapeOperator;
|
||||||
std::string batch_name = absl::StrCat(
|
slice_reshape_op->inputs = {
|
||||||
batch_op->outputs[0], "_b", absl::StrJoin(new_batch_prefix, "-"));
|
slice_op->outputs[0],
|
||||||
pack_op->inputs = pack_inputs;
|
CreateInt32Array(model, absl::StrCat(batch_name, "/reshape/shape"),
|
||||||
pack_op->outputs = {AvailableArrayName(*model, batch_name + "/pack")};
|
{num_rows, num_cols})};
|
||||||
auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
|
slice_reshape_op->outputs = {
|
||||||
pack_op_output.data_type = input_array_a.data_type;
|
AvailableArrayName(*model, absl::StrCat(batch_name, "/reshape"))};
|
||||||
pack_op->axis = 0;
|
auto& slice_reshape_op_output =
|
||||||
pack_op->values_count = pack_inputs.size();
|
model->GetOrCreateArray(slice_reshape_op->outputs[0]);
|
||||||
*tail_it = model->operators.emplace(*tail_it, pack_op) + 1;
|
slice_reshape_op_output.data_type = input_array.data_type;
|
||||||
|
*tail_it = model->operators.emplace(*tail_it, slice_reshape_op) + 1;
|
||||||
|
|
||||||
batch_pack_inputs.push_back(pack_op->outputs[0]);
|
slice_outputs.push_back(slice_reshape_op->outputs[0]);
|
||||||
}
|
}
|
||||||
}
|
return slice_outputs;
|
||||||
return batch_pack_inputs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int32> GetTransposePerm(const Array& input_array) {
|
std::vector<int32> GetTransposePerm(const Array& input_array) {
|
||||||
@ -202,15 +130,6 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
|
|||||||
// Unrolls a BatchMatMul on the batch dimension.
|
// Unrolls a BatchMatMul on the batch dimension.
|
||||||
// We need to slice each batch out of the inputs, matmul them individually, then
|
// We need to slice each batch out of the inputs, matmul them individually, then
|
||||||
// stack them all back together at the end.
|
// stack them all back together at the end.
|
||||||
//
|
|
||||||
// This transform effectively looks like:
|
|
||||||
// result_slices = []
|
|
||||||
// for bat in B:
|
|
||||||
// slice_a = tf.reshape(tf.slice(a, [bat, 0, 0], [1, M, N]), [M, N])
|
|
||||||
// slice_b = tf.reshape(tf.slice(b, [bat, 0, 0], [1, M, N]), [M, N])
|
|
||||||
// slice_c = tf.matmul(slice_a, slice_b)
|
|
||||||
// result_slices[bat] = slice_c
|
|
||||||
// result = tf.stack(result_slices)
|
|
||||||
::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
|
::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
|
||||||
bool* modified) {
|
bool* modified) {
|
||||||
*modified = false;
|
*modified = false;
|
||||||
@ -220,7 +139,6 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
|
|||||||
}
|
}
|
||||||
const auto* batch_op =
|
const auto* batch_op =
|
||||||
static_cast<const BatchMatMulOperator*>(batch_op_it->get());
|
static_cast<const BatchMatMulOperator*>(batch_op_it->get());
|
||||||
|
|
||||||
auto& tail_it = batch_op_it;
|
auto& tail_it = batch_op_it;
|
||||||
|
|
||||||
string input_lhs = batch_op->inputs[0];
|
string input_lhs = batch_op->inputs[0];
|
||||||
@ -246,20 +164,25 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
|
|||||||
}
|
}
|
||||||
const auto& input_array_b = model->GetArray(input_rhs);
|
const auto& input_array_b = model->GetArray(input_rhs);
|
||||||
|
|
||||||
const int dims = input_array_a.shape().dimensions_count();
|
// Ensure that input ranks are at least 2 and batch shapes are broadcastable.
|
||||||
for (int i = 0; i < dims - 2; ++i) {
|
const int dims_a = input_array_a.shape().dimensions_count();
|
||||||
CHECK_EQ(input_array_a.shape().dims(i), input_array_b.shape().dims(i))
|
const int dims_b = input_array_b.shape().dimensions_count();
|
||||||
<< "input array not consistent at index " << i;
|
CHECK_GE(dims_a, 2) << "First input must have rank >= 2";
|
||||||
}
|
CHECK_GE(dims_b, 2) << "Second input must have rank >= 2";
|
||||||
CHECK_EQ(input_array_a.shape().dims(dims - 1),
|
|
||||||
input_array_b.shape().dims(dims - 2))
|
::tensorflow::MatMulBCast bcast(
|
||||||
|
ToInlinedVector(input_array_a.shape().dims()),
|
||||||
|
ToInlinedVector(input_array_b.shape().dims()));
|
||||||
|
CHECK(bcast.IsValid()) << "Input batch dimensions must be broadcastable";
|
||||||
|
|
||||||
|
CHECK_EQ(input_array_a.shape().dims(dims_a - 1),
|
||||||
|
input_array_b.shape().dims(dims_b - 2))
|
||||||
<< "Input dimensions must be compatible for multipication. shape a = ["
|
<< "Input dimensions must be compatible for multipication. shape a = ["
|
||||||
<< absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = ["
|
<< absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = ["
|
||||||
<< absl::StrJoin(input_array_b.shape().dims(), ", ") << "]";
|
<< absl::StrJoin(input_array_b.shape().dims(), ", ") << "]";
|
||||||
|
|
||||||
if (dims == 2) {
|
if (dims_a == 2 && dims_b == 2) {
|
||||||
// This is really just a MatMul. This likely means that someone hand-crafted
|
// This is really just a MatMul.
|
||||||
// a graphdef with a BatchMatMul when they really wanted a MatMul.
|
|
||||||
AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator",
|
AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator",
|
||||||
LogName(*batch_op));
|
LogName(*batch_op));
|
||||||
auto* matmul_op = new TensorFlowMatMulOperator;
|
auto* matmul_op = new TensorFlowMatMulOperator;
|
||||||
@ -271,23 +194,65 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
|
|||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK_GE(input_array_a.shape().dimensions_count(), 3)
|
|
||||||
<< "Input arrays must have rank >= 3";
|
|
||||||
|
|
||||||
const auto& dims_vec = input_array_a.shape().dims();
|
|
||||||
AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
|
AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
|
||||||
std::accumulate(dims_vec.begin(), dims_vec.end() - 2, 1,
|
bcast.output_batch_size());
|
||||||
std::multiplies<int>()));
|
string base_name = std::string(batch_op->outputs[0]);
|
||||||
|
|
||||||
std::vector<string> pack_inputs = UnrollBatchMatMulRecursion(
|
// Compute slices for each batch in the LHS and RHS.
|
||||||
input_lhs, input_rhs, batch_op, model, &tail_it, {});
|
std::vector<string> slice_a_outputs =
|
||||||
|
SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a,
|
||||||
|
model, &tail_it);
|
||||||
|
std::vector<string> slice_b_outputs =
|
||||||
|
SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b,
|
||||||
|
model, &tail_it);
|
||||||
|
|
||||||
|
// Compute (single batch) MatMul for each output batch. The MatMul outputs are
|
||||||
|
// then packed together into one output Tensor.
|
||||||
|
std::vector<string> pack_inputs;
|
||||||
|
for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
|
||||||
|
std::string batch_name =
|
||||||
|
absl::StrCat(batch_op->outputs[0], "_b", batch_idx);
|
||||||
|
const int a_batch_idx = bcast.IsBroadcastingRequired()
|
||||||
|
? bcast.x_batch_indices()[batch_idx]
|
||||||
|
: batch_idx;
|
||||||
|
const int b_batch_idx = bcast.IsBroadcastingRequired()
|
||||||
|
? bcast.y_batch_indices()[batch_idx]
|
||||||
|
: batch_idx;
|
||||||
|
auto* matmul_op = new TensorFlowMatMulOperator;
|
||||||
|
matmul_op->inputs = {slice_a_outputs[a_batch_idx],
|
||||||
|
slice_b_outputs[b_batch_idx]};
|
||||||
|
matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
|
||||||
|
auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
|
||||||
|
matmul_op_output.data_type = input_array_a.data_type;
|
||||||
|
tail_it = model->operators.emplace(tail_it, matmul_op) + 1;
|
||||||
|
|
||||||
|
// Add to stack.
|
||||||
|
pack_inputs.push_back(matmul_op->outputs[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine the result of each individual MatMul into a rank-3 Tensor.
|
||||||
auto* pack_op = new PackOperator;
|
auto* pack_op = new PackOperator;
|
||||||
pack_op->inputs = pack_inputs;
|
pack_op->inputs = pack_inputs;
|
||||||
pack_op->outputs = {batch_op->outputs[0]};
|
pack_op->outputs = {AvailableArrayName(*model, base_name + "/pack")};
|
||||||
|
auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
|
||||||
|
pack_op_output.data_type = input_array_a.data_type;
|
||||||
pack_op->axis = 0;
|
pack_op->axis = 0;
|
||||||
pack_op->values_count = pack_inputs.size();
|
pack_op->values_count = pack_inputs.size();
|
||||||
model->operators.emplace(tail_it, pack_op);
|
tail_it = model->operators.emplace(tail_it, pack_op) + 1;
|
||||||
|
|
||||||
|
// Reshape the rank-3 Tensor into the correct output shape.
|
||||||
|
const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
|
||||||
|
std::vector<int> result_shape(result_batch_shape.begin(),
|
||||||
|
result_batch_shape.end());
|
||||||
|
result_shape.push_back(input_array_a.shape().dims(dims_a - 2));
|
||||||
|
result_shape.push_back(input_array_b.shape().dims(dims_b - 1));
|
||||||
|
|
||||||
|
auto* reshape_result_op = new TensorFlowReshapeOperator;
|
||||||
|
reshape_result_op->inputs = {
|
||||||
|
pack_op->outputs[0],
|
||||||
|
CreateInt32Array(model, base_name + "/reshape_out/shape", result_shape)};
|
||||||
|
reshape_result_op->outputs = {batch_op->outputs[0]};
|
||||||
|
model->operators.emplace(tail_it, reshape_result_op);
|
||||||
|
|
||||||
// Remove the old batch matmul now that we've unrolled.
|
// Remove the old batch matmul now that we've unrolled.
|
||||||
batch_op_it = model->operators.begin();
|
batch_op_it = model->operators.begin();
|
||||||
|
@ -971,9 +971,8 @@ struct TensorFlowIdentityOperator : Operator {
|
|||||||
TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {}
|
TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Batch matrix multiplication operator. This comes from the (deprecated)
|
// Batch matrix multiplication operator. This comes from a tf.matmul where one
|
||||||
// tf.batch_matmul or a tf.matmul that has rank 3. dims(0) is the batch count
|
// of the operands has rank 3 or more.
|
||||||
// and it can be trivially unrolled into a series of matmuls on each element.
|
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
// inputs[0]: required: the left-hand side matrix
|
// inputs[0]: required: the left-hand side matrix
|
||||||
|
Loading…
Reference in New Issue
Block a user