diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index e9ab3e00876..9031d54070c 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -842,40 +842,6 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
   }];
 }
 
-def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [
-    NoSideEffect,
-    TFL_OperandHasAtleastRank<0, 2>,
-    TFL_OperandHasAtleastRank<1, 2>,
-    SameOperandsAndResultElementType]> {
-
-  let summary = "Batch Matrix Multiply Operator";
-
-  let description = [{
-Performs a batched matrix multiplication on the inputs. Follows the
-conventions of TensorFlow BatchMatMulV2, with support for unknown dimensions
-in the batch dimensions and broadcasting.
-
-    Inputs:
-      `inputs[0]`: required: input LHS
-      `inputs[1]`: required: input RHS
-      `adjoint_lhs`: optional: Transpose LHS (default false)
-      `adjoint_lhs`: optional: Transpose LHS (default false)
-  }];
-
-  let arguments = (ins
-    TFL_TensorOf<[F32]>:$lhs,
-    TFL_TensorOf<[F32]>:$rhs,
-    DefaultValuedAttr<BoolAttr, "false">:$adjoint_lhs,
-    DefaultValuedAttr<BoolAttr, "false">:$adjoint_rhs
-  );
-
-  let results = (outs
-    TFL_TensorOf<[F32]>:$output
-  );
-
-  let hasOptions = 1;
-}
-
 def TFL_GatherOp : TFL_Op<"gather", [
     NoSideEffect,
     SameOperandsAndResultsScale,
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index bb92e9af81b..25ee1d8ba5d 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -1497,28 +1497,3 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
 // CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
 // CHECK: return [[MUL]] : tensor<3x3xi32>
 }
-
-func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
-  %0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
-(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
-  return %0 : tensor<10x17xf32>
-// CHECK-LABEL: matmul_batch
-// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adjoint_lhs = false, adjoint_rhs = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
-}
-
-func @matmul_batchv2(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<2x10x17xf32> {
-  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
-(tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
-  return %0 : tensor<2x10x17xf32>
-// CHECK-LABEL: matmul_batchv2
-// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adjoint_lhs = false, adjoint_rhs = false} : (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
-}
-
-func @matmul_batchv2_unknown_dim(%arg0: tensor<?x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<?x10x17xf32> {
-  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
-(tensor<?x10x15xf32>, tensor<15x17xf32>) -> tensor<?x10x17xf32>
-  return %0 : tensor<?x10x17xf32>
-// CHECK-LABEL: matmul_batchv2_unknown_dim
-// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adjoint_lhs = false, adjoint_rhs = false} : (tensor<?x10x15xf32>, tensor<15x17xf32>) -> tensor<?x10x17xf32>
-}
-
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
index d3e497a604f..13ae216dc25 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -211,11 +211,6 @@ def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;
 
 def : Pat<(TF_AddOp $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
 def : Pat<(TF_AddV2Op $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
-// When batch size is known, TF BatchMatMul gets unfolded to TFL FullyConnected
-// with additional ops. In the case of unknown batch size, the match will
-// fall through to here and convert to TF Lite BatchMatMul.
-def : Pat<(TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>;
-def : Pat<(TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>;
 def : Pat<(TF_SubOp $lhs, $rhs), (TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
 def : Pat<(TF_MulOp $lhs, $rhs), (TFL_MulOp $lhs, $rhs, TFL_AF_None)>;
 def : Pat<(TF_RealDivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>;
diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc
index 46701204993..3e03b13ecbe 100644
--- a/tensorflow/lite/kernels/batch_matmul.cc
+++ b/tensorflow/lite/kernels/batch_matmul.cc
@@ -44,7 +44,6 @@ enum KernelType {
 struct OpData {
   // The index of the temporary tensors where we store transposed LHS/RHS.
   int scratch_tensor_index;
-  bool rhs_transposed;
 };
 
 struct OpContext {
@@ -64,8 +63,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // Creates two temp tensors to store the transposed LHS and/or RHS if
   // needed.
   auto* op_data = new OpData();
-  // If the RHS is constant, we only transpose once.
-  op_data->rhs_transposed = false;
   context->AddTensors(context, 2, &op_data->scratch_tensor_index);
   return op_data;
 }
@@ -128,12 +125,13 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                                      scratch_buffer_size));
   }
 
-  // We need a temp buffer for the RHS if we need to transpose the RHS. We
-  // transpose by default, so that the two inputs (LHS and RHS) are in a proper
-  // layout for our fast matrix multiplication routines. If the transpose flag
-  // is set by the caller, the data is already in the desired layout.
-  const bool rhs_needs_temp = !(op_context->params->adjoint_rhs);
-  if (rhs_needs_temp) {
+  // We need the RHS transposed in the standard case, so if the flag is set,
+  // we do nothing. If the flag is not set, we need this temporary space.
+  // Note: we assume that the RHS is an in-memory tensor. If RHS is from a
+  // constant buffer (e.g. a weights buffer) with allocation type
+  // kTfLiteMmapRo, then this logic must be updated (since a read-only buffer
+  // is in the opposite layout pattern).
+  if (!op_context->params->adjoint_rhs) {
     TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/1);
     const TfLiteTensor* rhs = op_context->rhs;
     int rhs_rank = NumDimensions(rhs);
@@ -146,11 +144,7 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
     scratch_buffer_size->data[rhs_rank - 1] = rhs->dims->data[rhs_rank - 2];
 
     scratch_buffer->type = op_context->rhs->type;
-    if (IsConstantTensor(op_context->rhs)) {
-      scratch_buffer->allocation_type = kTfLiteArenaRwPersistent;
-    } else {
-      scratch_buffer->allocation_type = kTfLiteArenaRw;
-    }
+    scratch_buffer->allocation_type = kTfLiteArenaRw;
     TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                      scratch_buffer_size));
   }
@@ -250,7 +244,6 @@ RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
-  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
   const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor);
   const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -265,14 +258,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* lhs_tensor =
       adjoint_lhs ? GetTemporary(context, node, 0) : lhs;
   if (!adjoint_rhs) {
-    // TODO(b/154760341) Constant tensors should already be transposed, but
-    // we transpose once if necessary for now.
-    if (!(IsConstantTensor(rhs) && op_data->rhs_transposed)) {
-      TransposeRowsColumns(
-          rhs, GetTensorData<float>(rhs), GetTemporary(context, node, 1),
-          GetTensorData<float>(GetTemporary(context, node, 1)));
-      op_data->rhs_transposed = true;
-    }
+    TransposeRowsColumns(
+        rhs, GetTensorData<float>(rhs), GetTemporary(context, node, 1),
+        GetTensorData<float>(GetTemporary(context, node, 1)));
   }
   if (adjoint_lhs) {
     TransposeRowsColumns(
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 70671f36265..763e90f07eb 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -877,6 +877,7 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
         expected_value.numpy(), actual_value[0], decimal=6)
 
   def testBatchMatMul(self):
+    self.skipTest('BatchMatMulV2 does not support unknown batch size.')
     input_data_1 = tf.constant(
         np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
     input_data_2 = tf.constant(
@@ -900,8 +901,7 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
     actual_value = self._evaluateTFLiteModel(
         tflite_model, [input_data_1, input_data_2],
         input_shapes=[([-1, 256, 256], [1, 256, 256])])
-    np.testing.assert_almost_equal(
-        expected_value.numpy(), actual_value[0], decimal=4)
+    np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
 
 
 if __name__ == '__main__':