diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 69fee0b58d4..ec7f851ba43 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -241,9 +241,15 @@ static void EmitGetBuiltinOpCode(const std::vector &defs, for (const auto *def : defs) { StringRef op_name = def->getName().drop_front(4); + auto operator_name = GetOperatorName(*def); + // TODO(b/149099381): Remove this part after kernels are added as + // builtin op. + if (operator_name == "ASSIGN_VARIABLE" || + operator_name == "READ_VARIABLE") { + continue; + } os << " if (isa(op))\n" - << " return tflite::BuiltinOperator_" << GetOperatorName(*def) - << ";\n"; + << " return tflite::BuiltinOperator_" << operator_name << ";\n"; } os << " return llvm::None;\n" diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index aec6bf067f6..929b6cfd5b7 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -4556,4 +4556,41 @@ the dimension is padded with zeros. ); } + +def TFL_AssignVariableOp : TFL_Op<"assign_variable", []> { + let summary = "Assigns a new value to a variable."; + + let description = [{ +Any ReadVariableOp with a control dependency on this op is guaranteed to return +this value or a subsequent newer value of the variable. + }]; + + let arguments = (ins + // TODO(b/149099381): Remove integer IDs after adding the new variable + // handle type. + TFL_TensorOf<[I32]>:$resource_id, + // TODO(b/149099381): Support other types too. + TFL_TensorOf<[F32]>:$value + ); + + let results = (outs); +} + +def TFL_ReadVariableOp : TFL_Op<"read_variable", []> { + let summary = "Reads variable value."; + + let description = [{ +Read variable data identified by 'resource_id'. + }]; + + let arguments = (ins + // TODO(b/149099381): Remove integer IDs after adding the new variable + // handle type. + TFL_TensorOf<[I32]>:$resource_id + ); + + // TODO(b/149099381): Support other types too. + let results = (outs TFL_TensorOf<[F32]>:$result); +} + #endif // TFL_OPS