TF BatchMatMul/V2 -> TFL BatchMatMul when batch size is unknown.
When batch size is known, TF BatchMatMul gets unfolded to TFL FullyConnected + additional ops. In the case of unknown batch size before this CL, conversion would fail (or would require Flex delegate). This transformation is currently for float only since TFL BatchMatMul built-in Op is float only. PiperOrigin-RevId: 308632958 Change-Id: Icf2fe2192a4f5463f8795481f7ce613d10c984d9
This commit is contained in:
parent
127aa2a6c0
commit
f761369203
@ -842,40 +842,6 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [
|
||||
NoSideEffect,
|
||||
TFL_OperandHasAtleastRank<0, 2>,
|
||||
TFL_OperandHasAtleastRank<1, 2>,
|
||||
SameOperandsAndResultElementType]> {
|
||||
|
||||
let summary = "Batch Matrix Multiply Operator";
|
||||
|
||||
let description = [{
|
||||
Performs a batched matrix multiplication on the inputs. Follows the
|
||||
conventions of TensorFlow BatchMatMulV2, with support for unknown dimensions
|
||||
in the batch dimensions and broadcasting.
|
||||
|
||||
Inputs:
|
||||
`inputs[0]`: required: input LHS
|
||||
`inputs[1]`: required: input RHS
|
||||
`adjoint_lhs`: optional: Transpose LHS (default false)
|
||||
`adjoint_lhs`: optional: Transpose LHS (default false)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32]>:$lhs,
|
||||
TFL_TensorOf<[F32]>:$rhs,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adjoint_lhs,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adjoint_rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_GatherOp : TFL_Op<"gather", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale,
|
||||
|
@ -1497,28 +1497,3 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xi32>
|
||||
}
|
||||
|
||||
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
|
||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
|
||||
(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
|
||||
return %0 : tensor<10x17xf32>
|
||||
// CHECK-LABEL: matmul_batch
|
||||
// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adjoint_lhs = false, adjoint_rhs = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
|
||||
}
|
||||
|
||||
func @matmul_batchv2(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<2x10x17xf32> {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
|
||||
(tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
|
||||
return %0 : tensor<2x10x17xf32>
|
||||
// CHECK-LABEL: matmul_batchv2
|
||||
// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adjoint_lhs = false, adjoint_rhs = false} : (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
|
||||
}
|
||||
|
||||
func @matmul_batchv2_unknown_dim(%arg0: tensor<?x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<?x10x17xf32> {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
|
||||
(tensor<?x10x15xf32>, tensor<15x17xf32>) -> tensor<?x10x17xf32>
|
||||
return %0 : tensor<?x10x17xf32>
|
||||
// CHECK-LABEL: matmul_batchv2_unknown_dim
|
||||
// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adjoint_lhs = false, adjoint_rhs = false} : (tensor<?x10x15xf32>, tensor<15x17xf32>) -> tensor<?x10x17xf32>
|
||||
}
|
||||
|
||||
|
@ -211,11 +211,6 @@ def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;
|
||||
|
||||
def : Pat<(TF_AddOp $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
|
||||
def : Pat<(TF_AddV2Op $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
|
||||
// When batch size is known, TF BatchMatMul gets unfolded to TFL FullyConnected
|
||||
// with additional ops. In the case of unknown batch size, the match will
|
||||
// fall through to here and convert to TF Lite BatchMatMul.
|
||||
def : Pat<(TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>;
|
||||
def : Pat<(TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>;
|
||||
def : Pat<(TF_SubOp $lhs, $rhs), (TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
|
||||
def : Pat<(TF_MulOp $lhs, $rhs), (TFL_MulOp $lhs, $rhs, TFL_AF_None)>;
|
||||
def : Pat<(TF_RealDivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>;
|
||||
|
@ -44,7 +44,6 @@ enum KernelType {
|
||||
struct OpData {
|
||||
// The index of the temporary tensors where we store transposed LHS/RHS.
|
||||
int scratch_tensor_index;
|
||||
bool rhs_transposed;
|
||||
};
|
||||
|
||||
struct OpContext {
|
||||
@ -64,8 +63,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
// Creates two temp tensors to store the transposed LHS and/or RHS if
|
||||
// needed.
|
||||
auto* op_data = new OpData();
|
||||
// If the RHS is constant, we only transpose once.
|
||||
op_data->rhs_transposed = false;
|
||||
context->AddTensors(context, 2, &op_data->scratch_tensor_index);
|
||||
return op_data;
|
||||
}
|
||||
@ -128,12 +125,13 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||
scratch_buffer_size));
|
||||
}
|
||||
|
||||
// We need a temp buffer for the RHS if we need to transpose the RHS. We
|
||||
// transpose by default, so that the two inputs (LHS and RHS) are in a proper
|
||||
// layout for our fast matrix multiplication routines. If the transpose flag
|
||||
// is set by the caller, the data is already in the desired layout.
|
||||
const bool rhs_needs_temp = !(op_context->params->adjoint_rhs);
|
||||
if (rhs_needs_temp) {
|
||||
// We need the RHS transposed in the standard case, so if the flag is set,
|
||||
// we do nothing. If the flag is not set, we need this temporary space.
|
||||
// Note: we assume that the RHS is an in-memory tensor. If RHS is from a
|
||||
// constant buffer (e.g. a weights buffer) with allocation type
|
||||
// kTfLiteMmapRo, then this logic must be updated (since a read-only buffer
|
||||
// is in the opposite layout pattern).
|
||||
if (!op_context->params->adjoint_rhs) {
|
||||
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/1);
|
||||
const TfLiteTensor* rhs = op_context->rhs;
|
||||
int rhs_rank = NumDimensions(rhs);
|
||||
@ -146,11 +144,7 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||
scratch_buffer_size->data[rhs_rank - 1] = rhs->dims->data[rhs_rank - 2];
|
||||
|
||||
scratch_buffer->type = op_context->rhs->type;
|
||||
if (IsConstantTensor(op_context->rhs)) {
|
||||
scratch_buffer->allocation_type = kTfLiteArenaRwPersistent;
|
||||
} else {
|
||||
scratch_buffer->allocation_type = kTfLiteArenaRw;
|
||||
}
|
||||
scratch_buffer->allocation_type = kTfLiteArenaRw;
|
||||
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
|
||||
scratch_buffer_size));
|
||||
}
|
||||
@ -250,7 +244,6 @@ RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
|
||||
template <KernelType kernel_type>
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpContext op_context(context, node);
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor);
|
||||
const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
@ -265,14 +258,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* lhs_tensor =
|
||||
adjoint_lhs ? GetTemporary(context, node, 0) : lhs;
|
||||
if (!adjoint_rhs) {
|
||||
// TODO(b/154760341) Constant tensors should already be transposed, but
|
||||
// we transpose once if necessary for now.
|
||||
if (!(IsConstantTensor(rhs) && op_data->rhs_transposed)) {
|
||||
TransposeRowsColumns<float>(
|
||||
rhs, GetTensorData<float>(rhs), GetTemporary(context, node, 1),
|
||||
GetTensorData<float>(GetTemporary(context, node, 1)));
|
||||
op_data->rhs_transposed = true;
|
||||
}
|
||||
TransposeRowsColumns<float>(
|
||||
rhs, GetTensorData<float>(rhs), GetTemporary(context, node, 1),
|
||||
GetTensorData<float>(GetTemporary(context, node, 1)));
|
||||
}
|
||||
if (adjoint_lhs) {
|
||||
TransposeRowsColumns<float>(
|
||||
|
@ -877,6 +877,7 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
expected_value.numpy(), actual_value[0], decimal=6)
|
||||
|
||||
def testBatchMatMul(self):
|
||||
self.skipTest('BatchMatMulV2 does not support unknown batch size.')
|
||||
input_data_1 = tf.constant(
|
||||
np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
|
||||
input_data_2 = tf.constant(
|
||||
@ -900,8 +901,7 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
actual_value = self._evaluateTFLiteModel(
|
||||
tflite_model, [input_data_1, input_data_2],
|
||||
input_shapes=[([-1, 256, 256], [1, 256, 256])])
|
||||
np.testing.assert_almost_equal(
|
||||
expected_value.numpy(), actual_value[0], decimal=4)
|
||||
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user