Add unidirectional sequence rnn op_def to graphdef_to_flatbuffer and also add a e2e test.

PiperOrigin-RevId: 297050448
Change-Id: Ifa7249a5e4585f61ea9833f11ea28a9f2f9e0363
This commit is contained in:
Renjie Liu 2020-02-24 23:15:38 -08:00 committed by TensorFlower Gardener
parent 7ed5e24507
commit 24b2d0252b
2 changed files with 20 additions and 0 deletions

View File

@ -87,6 +87,17 @@ const char kUnidirectionalSequenceLstmOp[] =
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
const char kUnidirectionalSequenceRnnOp[] =
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
"name: 'Bias' type: DT_FLOAT} "
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
"output_arg: { name: "
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
"DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
@ -285,6 +296,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
toco_flags.custom_opdefs().end());
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
TF_ASSIGN_OR_RETURN(

View File

@ -249,6 +249,10 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
# Test MLIR-converted model.
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, True)
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
@test_util.enable_control_flow_v2
def testDynamicRnnMultiRnnCell(self):
sess = tf.compat.v1.Session(config=CONFIG)
@ -269,6 +273,10 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
# Test MLIR-converted model.
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, True)
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
if __name__ == "__main__":
test.main()