Added enum for batched matmul operator.

PiperOrigin-RevId: 329999689
Change-Id: Id2fcacaa14a4a0858691dd391bf7f3c3d99175db
This commit is contained in:
Raman Sarokin 2020-09-03 14:57:41 -07:00 committed by TensorFlower Gardener
parent 565f670d9f
commit 2cea1ba130
3 changed files with 5 additions and 0 deletions

View File

@ -82,6 +82,8 @@ std::string ToString(enum OperationType op) {
return "batch_normalization";
case OperationType::BATCH_TO_SPACE:
return "batch_to_space";
case OperationType::BATCHED_MATMUL:
return "batched_matmul";
case OperationType::CONCAT:
return "concat";
case OperationType::CONST:
@ -195,6 +197,7 @@ OperationType OperationTypeFromString(const std::string& name) {
{"abs", OperationType::ABS},
{"add", OperationType::ADD},
{"batch_normalization", OperationType::BATCH_NORMALIZATION},
{"batched_matmul", OperationType::BATCHED_MATMUL},
{"concat", OperationType::CONCAT},
{"const", OperationType::CONST},
{"convolution_2d", OperationType::CONVOLUTION_2D},

View File

@ -37,6 +37,7 @@ enum class OperationType {
ADD,
BATCH_TO_SPACE,
BATCH_NORMALIZATION,
BATCHED_MATMUL,
CONCAT,
CONST,
CONVOLUTION_2D,

View File

@ -404,6 +404,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
} break;
case OperationType::BATCH_NORMALIZATION:
case OperationType::BATCH_TO_SPACE:
case OperationType::BATCHED_MATMUL:
case OperationType::CONST:
case OperationType::LSTM:
// TODO(b/162763635): implement MeanStddevNormalization for Metal.