Add unidirectional sequence rnn op_def to graphdef_to_flatbuffer and also add a e2e test.
PiperOrigin-RevId: 297050448 Change-Id: Ifa7249a5e4585f61ea9833f11ea28a9f2f9e0363
This commit is contained in:
parent
7ed5e24507
commit
24b2d0252b
@ -87,6 +87,17 @@ const char kUnidirectionalSequenceLstmOp[] =
|
||||
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
|
||||
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||
|
||||
const char kUnidirectionalSequenceRnnOp[] =
|
||||
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
|
||||
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
|
||||
"name: 'Bias' type: DT_FLOAT} "
|
||||
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
|
||||
"output_arg: { name: "
|
||||
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
|
||||
"DT_FLOAT} "
|
||||
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||
|
||||
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
||||
// conversion mapping for constants defined in TFLite Python API.
|
||||
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
@ -285,6 +296,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
toco_flags.custom_opdefs().end());
|
||||
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
||||
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
|
||||
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
|
||||
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
||||
@ -249,6 +249,10 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
|
||||
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
|
||||
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
|
||||
|
||||
# Test MLIR-converted model.
|
||||
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, True)
|
||||
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
def testDynamicRnnMultiRnnCell(self):
|
||||
sess = tf.compat.v1.Session(config=CONFIG)
|
||||
@ -269,6 +273,10 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
|
||||
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
|
||||
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
|
||||
|
||||
# Test MLIR-converted model.
|
||||
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, True)
|
||||
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user