Add unidirectional_sequence_rnn stateful ops.

PiperOrigin-RevId: 260766978
Renjie Liu authored on 2019-07-30 12:06:02 -07:00, committed by TensorFlower Gardener
parent e7206a7d8e
commit 900087dd51
2 changed files with 97 additions and 0 deletions
tensorflow/compiler/mlir/lite

@@ -884,6 +884,9 @@ bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
  } else if (auto tfl =
                 llvm::dyn_cast<mlir::TFL::UnidirectionalSequenceLSTMOp>(op)) {
    operand_indices = tfl.GetStatefulOperands();
  } else if (auto tfl =
                 llvm::dyn_cast<mlir::TFL::UnidirectionalSequenceRNNOp>(op)) {
    operand_indices = tfl.GetStatefulOperands();
  }
  return absl::c_find(operand_indices, operand_index) != operand_indices.end();
}
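
For orientation, a minimal self-contained sketch of the pattern this hunk extends: a stateful TFL op exposes GetStatefulOperands(), and the translator only checks whether a given operand index appears in that list. The class and function names below (UnidirectionalSequenceRNNOpSketch, IsStatefulOperandSketch) and the hard-coded index {4} are illustrative assumptions, not the generated op class; in TensorFlow the stateful indices come from the ODS definition of mlir::TFL::UnidirectionalSequenceRNNOp.

#include <algorithm>
#include <iostream>
#include <vector>

struct UnidirectionalSequenceRNNOpSketch {
  // Hypothetical stand-in for the generated op class: reports which operand
  // indices carry state. Here operand 4 is assumed to be the hidden state,
  // matching the "Const" input that the test below expects to be
  // is_variable: true.
  std::vector<int> GetStatefulOperands() const { return {4}; }
};

// Mirrors the shape of Translator::IsStatefulOperand in the hunk above:
// an operand is stateful iff its index is in the op's stateful-operand list.
bool IsStatefulOperandSketch(const UnidirectionalSequenceRNNOpSketch& op,
                             int operand_index) {
  const std::vector<int> indices = op.GetStatefulOperands();
  return std::find(indices.begin(), indices.end(), operand_index) !=
         indices.end();
}

int main() {
  UnidirectionalSequenceRNNOpSketch rnn;
  std::cout << IsStatefulOperandSketch(rnn, 4) << "\n";  // 1: hidden state
  std::cout << IsStatefulOperandSketch(rnn, 0) << "\n";  // 0: plain input
  return 0;
}

The new test below exercises exactly this path: the fifth operator input (tensor "Const", buffer 5) is the state tensor, which is why the expected flatbuffer marks it is_variable: true.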

@@ -0,0 +1,94 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "tfl.pseudo_input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.pseudo_input1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.pseudo_input2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tfl.pseudo_input3",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: is_variable: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "tfl.unidirectional_sequence_rnn",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1, 2, 3 ],
// CHECK-NEXT: outputs: [ 5 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4 ],
// CHECK-NEXT: outputs: [ 5 ],
// CHECK-NEXT: builtin_options_type: SequenceRNNOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: time_major: true,
// CHECK-NEXT: fused_activation_function: TANH
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
// CHECK-EMPTY:
^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>):
  %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32>
  %1 = "tfl.pseudo_input" (%arg1) : (tensor<4 x f32>) -> tensor<4 x f32>
  %2 = "tfl.pseudo_input" (%arg2) : (tensor<4 x f32>) -> tensor<4 x f32>
  %3 = "tfl.pseudo_input" (%arg3) : (tensor<4 x f32>) -> tensor<4 x f32>
  %4 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
  %5 = "tfl.unidirectional_sequence_rnn"(%0, %1, %2, %3, %4) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %5 : tensor<4xf32>
}