Added enum for batched matmul operator.
PiperOrigin-RevId: 329999689 Change-Id: Id2fcacaa14a4a0858691dd391bf7f3c3d99175db
This commit is contained in:
parent
565f670d9f
commit
2cea1ba130
@ -82,6 +82,8 @@ std::string ToString(enum OperationType op) {
|
||||
return "batch_normalization";
|
||||
case OperationType::BATCH_TO_SPACE:
|
||||
return "batch_to_space";
|
||||
case OperationType::BATCHED_MATMUL:
|
||||
return "batched_matmul";
|
||||
case OperationType::CONCAT:
|
||||
return "concat";
|
||||
case OperationType::CONST:
|
||||
@ -195,6 +197,7 @@ OperationType OperationTypeFromString(const std::string& name) {
|
||||
{"abs", OperationType::ABS},
|
||||
{"add", OperationType::ADD},
|
||||
{"batch_normalization", OperationType::BATCH_NORMALIZATION},
|
||||
{"batched_matmul", OperationType::BATCHED_MATMUL},
|
||||
{"concat", OperationType::CONCAT},
|
||||
{"const", OperationType::CONST},
|
||||
{"convolution_2d", OperationType::CONVOLUTION_2D},
|
||||
|
@ -37,6 +37,7 @@ enum class OperationType {
|
||||
ADD,
|
||||
BATCH_TO_SPACE,
|
||||
BATCH_NORMALIZATION,
|
||||
BATCHED_MATMUL,
|
||||
CONCAT,
|
||||
CONST,
|
||||
CONVOLUTION_2D,
|
||||
|
@ -404,6 +404,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
} break;
|
||||
case OperationType::BATCH_NORMALIZATION:
|
||||
case OperationType::BATCH_TO_SPACE:
|
||||
case OperationType::BATCHED_MATMUL:
|
||||
case OperationType::CONST:
|
||||
case OperationType::LSTM:
|
||||
// TODO(b/162763635): implement MeanStddevNormalization for Metal.
|
||||
|
Loading…
x
Reference in New Issue
Block a user