Added enum for batched matmul operator.
PiperOrigin-RevId: 329999689 Change-Id: Id2fcacaa14a4a0858691dd391bf7f3c3d99175db
This commit is contained in:
parent
565f670d9f
commit
2cea1ba130
@ -82,6 +82,8 @@ std::string ToString(enum OperationType op) {
|
|||||||
return "batch_normalization";
|
return "batch_normalization";
|
||||||
case OperationType::BATCH_TO_SPACE:
|
case OperationType::BATCH_TO_SPACE:
|
||||||
return "batch_to_space";
|
return "batch_to_space";
|
||||||
|
case OperationType::BATCHED_MATMUL:
|
||||||
|
return "batched_matmul";
|
||||||
case OperationType::CONCAT:
|
case OperationType::CONCAT:
|
||||||
return "concat";
|
return "concat";
|
||||||
case OperationType::CONST:
|
case OperationType::CONST:
|
||||||
@ -195,6 +197,7 @@ OperationType OperationTypeFromString(const std::string& name) {
|
|||||||
{"abs", OperationType::ABS},
|
{"abs", OperationType::ABS},
|
||||||
{"add", OperationType::ADD},
|
{"add", OperationType::ADD},
|
||||||
{"batch_normalization", OperationType::BATCH_NORMALIZATION},
|
{"batch_normalization", OperationType::BATCH_NORMALIZATION},
|
||||||
|
{"batched_matmul", OperationType::BATCHED_MATMUL},
|
||||||
{"concat", OperationType::CONCAT},
|
{"concat", OperationType::CONCAT},
|
||||||
{"const", OperationType::CONST},
|
{"const", OperationType::CONST},
|
||||||
{"convolution_2d", OperationType::CONVOLUTION_2D},
|
{"convolution_2d", OperationType::CONVOLUTION_2D},
|
||||||
|
@ -37,6 +37,7 @@ enum class OperationType {
|
|||||||
ADD,
|
ADD,
|
||||||
BATCH_TO_SPACE,
|
BATCH_TO_SPACE,
|
||||||
BATCH_NORMALIZATION,
|
BATCH_NORMALIZATION,
|
||||||
|
BATCHED_MATMUL,
|
||||||
CONCAT,
|
CONCAT,
|
||||||
CONST,
|
CONST,
|
||||||
CONVOLUTION_2D,
|
CONVOLUTION_2D,
|
||||||
|
@ -404,6 +404,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
|||||||
} break;
|
} break;
|
||||||
case OperationType::BATCH_NORMALIZATION:
|
case OperationType::BATCH_NORMALIZATION:
|
||||||
case OperationType::BATCH_TO_SPACE:
|
case OperationType::BATCH_TO_SPACE:
|
||||||
|
case OperationType::BATCHED_MATMUL:
|
||||||
case OperationType::CONST:
|
case OperationType::CONST:
|
||||||
case OperationType::LSTM:
|
case OperationType::LSTM:
|
||||||
// TODO(b/162763635): implement MeanStddevNormalization for Metal.
|
// TODO(b/162763635): implement MeanStddevNormalization for Metal.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user