diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc index b7bf9a2f8c9..94e1f0ad5e6 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.cc +++ b/tensorflow/lite/delegates/gpu/common/operations.cc @@ -82,6 +82,8 @@ std::string ToString(enum OperationType op) { return "batch_normalization"; case OperationType::BATCH_TO_SPACE: return "batch_to_space"; + case OperationType::BATCHED_MATMUL: + return "batched_matmul"; case OperationType::CONCAT: return "concat"; case OperationType::CONST: @@ -195,6 +197,7 @@ OperationType OperationTypeFromString(const std::string& name) { {"abs", OperationType::ABS}, {"add", OperationType::ADD}, {"batch_normalization", OperationType::BATCH_NORMALIZATION}, + {"batched_matmul", OperationType::BATCHED_MATMUL}, {"concat", OperationType::CONCAT}, {"const", OperationType::CONST}, {"convolution_2d", OperationType::CONVOLUTION_2D}, diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h index e92551ffffe..968f75e155b 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.h +++ b/tensorflow/lite/delegates/gpu/common/operations.h @@ -37,6 +37,7 @@ enum class OperationType { ADD, BATCH_TO_SPACE, BATCH_NORMALIZATION, + BATCHED_MATMUL, CONCAT, CONST, CONVOLUTION_2D, diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index 65babe56fca..5773457b642 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -404,6 +404,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, } break; case OperationType::BATCH_NORMALIZATION: case OperationType::BATCH_TO_SPACE: + case OperationType::BATCHED_MATMUL: case OperationType::CONST: case OperationType::LSTM: // TODO(b/162763635): implement MeanStddevNormalization for Metal.