/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Unit test for TFLite Bidirectional RNN op.

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <tuple>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace {

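// Modes for the optional auxiliary input, mirroring how the test model below
// wires up its tensors:
//  - kNoAuxInput: the aux input and both aux weight tensors are null inputs.
//  - kCrossLinking: the aux input is a real tensor that is consumed through
//    dedicated aux forward/backward weight tensors.
//  - kNoCrossLinking: the aux input is a real tensor, but the aux weight
//    tensors are null inputs.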
enum class AuxInputMode {
  kNoAuxInput,
  kCrossLinking,
  kNoCrossLinking,
};

using ::testing::ElementsAreArray;

static float rnn_input[] = {
    0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133,
    0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471,
    -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222,
    0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933,
    0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103,
    0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043,
    -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007,
    -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154,
    0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584,
    0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144,
    0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351,
    -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719,
    0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567,
    -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881,
    -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032,
    -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374,
    0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071,
    -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219,
    -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682,
    0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493,
    -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265,
    0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539,
    0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446,
    0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
    -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563,
    0.93455386, -0.6324693, -0.083922029};

static float rnn_golden_fw_output[] = {
    0.496726, 0, 0.965996, 0, 0.0584254, 0,
    0, 0.12315, 0, 0, 0.612266, 0.456601,
    0, 0.52286, 1.16099, 0.0291232,

    0, 0, 0.524901, 0, 0, 0,
    0, 1.02116, 0, 1.35762, 0, 0.356909,
    0.436415, 0.0355727, 0, 0,

    0, 0, 0, 0.262335, 0, 0,
    0, 1.33992, 0, 2.9739, 0, 0,
    1.31914, 2.66147, 0, 0,

    0.942568, 0, 0, 0, 0.025507, 0,
    0, 0, 0.321429, 0.569141, 1.25274, 1.57719,
    0.8158, 1.21805, 0.586239, 0.25427,

    1.04436, 0, 0.630725, 0, 0.133801, 0.210693,
    0.363026, 0, 0.533426, 0, 1.25926, 0.722707,
    0, 1.22031, 1.30117, 0.495867,

    0.222187, 0, 0.72725, 0, 0.767003, 0,
    0, 0.147835, 0, 0, 0, 0.608758,
    0.469394, 0.00720298, 0.927537, 0,

    0.856974, 0.424257, 0, 0, 0.937329, 0,
    0, 0, 0.476425, 0, 0.566017, 0.418462,
    0.141911, 0.996214, 1.13063, 0,

    0.967899, 0, 0, 0, 0.0831304, 0,
    0, 1.00378, 0, 0, 0, 1.44818,
    1.01768, 0.943891, 0.502745, 0,

    0.940135, 0, 0, 0, 0, 0,
    0, 2.13243, 0, 0.71208, 0.123918, 1.53907,
    1.30225, 1.59644, 0.70222, 0,

    0.804329, 0, 0.430576, 0, 0.505872, 0.509603,
    0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311,
    0.0454298, 0.300267, 0.562784, 0.395095,

    0.228154, 0, 0.675323, 0, 1.70536, 0.766217,
    0, 0, 0, 0.735363, 0.0759267, 1.91017,
    0.941888, 0, 0, 0,

    0, 0, 1.5909, 0, 0, 0,
    0, 0.5755, 0, 0.184687, 0, 1.56296,
    0.625285, 0, 0, 0,

    0, 0, 0.0857888, 0, 0, 0,
    0, 0.488383, 0.252786, 0, 0, 0,
    1.02817, 1.85665, 0, 0,

    0.00981836, 0, 1.06371, 0, 0, 0,
    0, 0, 0, 0.290445, 0.316406, 0,
    0.304161, 1.25079, 0.0707152, 0,

    0.986264, 0.309201, 0, 0, 0, 0,
    0, 1.64896, 0.346248, 0, 0.918175, 0.78884,
    0.524981, 1.92076, 2.07013, 0.333244,

    0.415153, 0.210318, 0, 0, 0, 0,
    0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
    0.628881, 3.58099, 1.49974, 0};

static float rnn_golden_bw_output[] = {
    0.496726, 0, 1.00883, 0, 0.0584256, 0, 0,
    0.236412, 0, 0, 0.612267, 0.487726, 0, 0.54883,
    1.16099, 0.0291233, 0, 0, 0.428302, 0, 0,
    0, 0, 1.13262, 0, 1.64415, 0, 0.311249,
    0.570804, 0.259696, 0, 0, 0, 0, 0,
    0.262334, 0, 0, 0, 1.23781, 0, 2.86532,
    0, 0, 1.34389, 2.76409, 0, 0, 1.03969,
    0, 0.00410865, 0, 0.0470295, 0, 0, 0,
    0.371556, 0.27175, 1.36614, 1.63956, 0.683887, 1.06176, 0.719552,
    0.301314, 0.971195, 0, 0.697143, 0, 0.215219, 0.210693,
    0.363027, 0, 0.501283, 0, 1.13399, 0.623774, 0,
    1.09851, 1.33313, 0.470441, 0.210965, 0, 0.664178, 0,
    0.839686, 0, 0, 0.147834, 0, 0, 0,
    0.58786, 0.490128, 0, 0.905806, 0, 0.932134, 0.424257,
    0, 0, 0.860629, 0, 0, 0, 0.476425,
    0, 0.566017, 0.513721, 0.207341, 1.09508, 1.08385, 0,
    0.973787, 0, 0, 0, 0, 0, 0,
    1.20698, 0, 0, 0, 1.56135, 1.12369, 0.99588,
    0.459803, 0, 0.915854, 0, 0, 0, 0,
    0, 0, 2.03206, 0, 0.773264, 0.267228, 1.55012,
    1.202, 1.51611, 0.701202, 0, 0.725088, 0, 0.509069,
    0, 0.671349, 0.581129, 0.343447, 0, 0.107755, 0.611838,
    1.4331, 1.55871, 0.015242, 0.140624, 0.492562, 0.395095, 0.147722,
    0, 0.784925, 0, 1.65477, 0.715257, 0, 0,
    0, 0.685024, 0, 1.89505, 1.00037, 0, 0,
    0, 0, 0, 1.52659, 0, 0, 0,
    0, 0.618583, 0, 0.11115, 0, 1.37194, 0.630225,
    0, 0, 0, 0, 0, 0.0322124, 0,
    0, 0, 0, 0.430834, 0.252786, 0, 0,
    0, 0.991297, 1.98451, 0, 0, 0.111511, 0,
    1.05513, 0, 0, 0, 0, 0, 0,
    0.290445, 0.412559, 0.0429958, 0.256564, 1.27858, 0.289948, 0,
    1.01693, 0.327141, 0, 0, 0, 0, 0,
    1.83508, 0.346248, 0, 0.961535, 0.790026, 0.552203, 2.13457,
    2.19233, 0.333244, 0.316526, 0.179398, 0, 0, 0,
    0, 0, 1.86126, 0, 0.728256, 0.750013, 0.011861,
    0.576383, 3.38891, 1.29273, 0};

const std::initializer_list<float> weights = {
    0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
    0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
    0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
    -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
    -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
    -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
    -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
    0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
    0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
    0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
    -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
    0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
    -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
    -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
    0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
    0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
    0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
    -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
    0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
    0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
    -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
    0.277308, 0.415818};

static float endtoend_input[] = {
    0.996808, 0.060710, 0.981855, 0.570017, 0.525164, 0.796859, 0.696547,
    0.505925, 0.991844, 0.461208, 0.949371, 0.027624, 0.539236, 0.841854,
    0.915222, 0.538569, 0.069375, 0.237905, 0.903700, 0.441703, 0.536196,
    0.402724, 0.761635, 0.025063, 0.082592, 0.688245, 0.239310, 0.256931,
    0.658900, 0.105695, 0.301983, 0.655708, 0.166405, 0.283837, 0.225725,
    0.691569, 0.080696, 0.922272, 0.197494, 0.072540, 0.383481, 0.146865,
    0.100163, 0.922717, 0.988720, 0.015386, 0.461286, 0.058095, 0.253290,
    0.364986, 0.499797, 0.789487, 0.767709, 0.261433, 0.814549, 0.850302,
    0.949678, 0.053859, 0.107233, 0.608577, 0.159554, 0.409215, 0.264285,
    0.325960, 0.693053, 0.490011, 0.017529, 0.773749, 0.412283, 0.215023,
    0.846288, 0.795764, 0.361889, 0.946452, 0.718481, 0.350608, 0.961837,
    0.179767, 0.408703, 0.215128, 0.544753, 0.908500, 0.004614, 0.312462,
    0.169933, 0.819163, 0.162764, 0.119611, 0.873022, 0.269997, 0.728188,
    0.032576, 0.679212, 0.992474, 0.358536, 0.372265, 0.482484, 0.376065,
    0.146014, 0.894767, 0.591088, 0.992302, 0.690531, 0.952977, 0.938754,
    0.409012, 0.303585, 0.900591, 0.588780, 0.712287, 0.115719, 0.133533,
    0.620788, 0.120334, 0.445995, 0.790720, 0.939497, 0.608759, 0.910331,
    0.812519, 0.878756, 0.638519, 0.845096, 0.557968, 0.630993, 0.203632,
    0.930233, 0.113477, 0.579697, 0.076247, 0.008244, 0.170785, 0.068549,
    0.698776, 0.123761, 0.007303, 0.107788, 0.427346, 0.907894, 0.696568,
    0.139633, 0.023613, 0.830100, 0.760421, 0.143947, 0.276096, 0.551141,
    0.083444, 0.884855, 0.461472, 0.895963, 0.763611, 0.099992, 0.741059,
    0.321579, 0.730984, 0.944691, 0.251812, 0.844461, 0.524388, 0.328059,
    0.852706, 0.695172, 0.396607, 0.551482, 0.818934, 0.403910, 0.659270,
    0.246280, 0.311804, 0.355838, 0.385913, 0.335418, 0.185938, 0.146334,
    0.479364, 0.462034, 0.697475, 0.562808, 0.346888, 0.158948, 0.458771,
    0.110499, 0.258939, 0.199830, 0.432078, 0.989924, 0.144521, 0.683890,
    0.834385, 0.668908, 0.011949, 0.687091, 0.364081, 0.408556, 0.238572,
    0.183015, 0.812466, 0.897842, 0.429294, 0.124271, 0.253680, 0.815207,
    0.459688, 0.439618, 0.961541, 0.939053, 0.901651, 0.659016, 0.501861,
    0.248539, 0.817964, 0.960632, 0.359038, 0.076903, 0.160462, 0.791117,
    0.066826, 0.304983, 0.475007, 0.901211, 0.973891, 0.486955, 0.588302,
    0.337972, 0.895512, 0.826874, 0.520987, 0.707978, 0.724716, 0.950281,
    0.832249, 0.978396, 0.765488, 0.291937, 0.418014, 0.727029, 0.230990,
    0.319665, 0.386045, 0.732850, 0.568204, 0.204009, 0.693482, 0.927242,
    0.280912, 0.853944, 0.718359, 0.347738, 0.158927, 0.193366, 0.248950,
    0.132818, 0.680321, 0.837252, 0.470790, 0.575833, 0.664126, 0.991777,
    0.283811, 0.388843, 0.942058, 0.116060, 0.367239, 0.707546, 0.407997,
    0.785253, 0.434575, 0.638986, 0.104917, 0.820620, 0.371837, 0.673121,
    0.024629, 0.065319, 0.600363, 0.305541, 0.919263, 0.318722, 0.653279,
    0.078190, 0.512088, 0.902229, 0.211009, 0.192409, 0.739480, 0.681799,
    0.768242, 0.403607, 0.673576, 0.052052, 0.792450, 0.615634, 0.168112,
    0.159689, 0.323180, 0.576109, 0.944941, 0.757755, 0.215095, 0.049858,
    0.578375, 0.586932, 0.722979, 0.603003, 0.652251, 0.323343, 0.908544,
    0.571514, 0.642065, 0.561823, 0.649704, 0.154153, 0.464051, 0.860713,
    0.346562, 0.203532, 0.542512, 0.114804, 0.607139, 0.216088, 0.166856,
    0.399588, 0.831722, 0.334968, 0.559277, 0.154902, 0.911077, 0.504218,
    0.912656, 0.126172, 0.554076, 0.491031, 0.713104, 0.277055, 0.094034,
    0.365355, 0.600398, 0.002578, 0.936869, 0.242463, 0.564401, 0.586574,
    0.396616, 0.028452, 0.447287, 0.743178, 0.231984, 0.989799, 0.857982,
    0.839122, 0.205887, 0.024838, 0.238711, 0.037608, 0.359806, 0.797987,
    0.192510, 0.270883, 0.302205, 0.105166, 0.397055, 0.856281, 0.596197,
    0.110160, 0.133336, 0.690231, 0.475515, 0.733734, 0.692809, 0.412384,
    0.976196, 0.257209, 0.998958, 0.372812, 0.285661, 0.446245, 0.115990,
    0.517645, 0.436044, 0.973972, 0.356767, 0.641930, 0.998810, 0.595478,
    0.679539, 0.358617, 0.393465, 0.872049, 0.629500, 0.695670, 0.977215,
    0.026555, 0.551951, 0.573412, 0.136715, 0.685287, 0.263643, 0.612229,
    0.419020, 0.956451, 0.024613, 0.395216, 0.213661, 0.023572, 0.768029,
    0.499322, 0.469816, 0.884019, 0.016967, 0.905860, 0.857991, 0.373734,
    0.547791, 0.856802, 0.969211, 0.227330, 0.215418, 0.362676, 0.099378,
    0.844918, 0.058346, 0.076594, 0.871473, 0.610297, 0.650006, 0.008188,
    0.295583, 0.913648, 0.620417, 0.714603, 0.870100, 0.645031, 0.109820,
    0.083760, 0.668602, 0.877849, 0.583082, 0.138419, 0.761868, 0.600049,
    0.044279, 0.619859, 0.973783, 0.592069, 0.476661, 0.942994, 0.819399,
    0.692079, 0.305670, 0.918778, 0.536997, 0.364016, 0.995371, 0.408470,
    0.974313, 0.645377, 0.416658, 0.269896, 0.559025, 0.037075, 0.984499,
    0.429125, 0.682105, 0.094319, 0.512885, 0.350707, 0.972168, 0.095967,
    0.489126, 0.734035, 0.696016, 0.533405, 0.353894, 0.669799, 0.125474,
    0.830555, 0.612793, 0.944873, 0.522634, 0.918463, 0.863651, 0.059631,
    0.282479, 0.859022, 0.468101, 0.256791, 0.504398, 0.884758, 0.526687,
    0.063423, 0.921833, 0.511186, 0.492548, 0.603939, 0.605505, 0.005433,
    0.954646, 0.577673, 0.101400, 0.443772, 0.311708, 0.797417, 0.977176,
    0.665602, 0.467216, 0.102650, 0.496157, 0.080009, 0.047524, 0.018791,
    0.998471, 0.911174, 0.078422, 0.280950, 0.770196, 0.546523, 0.537741,
    0.274594, 0.431281, 0.064428, 0.338017, 0.353115, 0.575615, 0.830565,
    0.957053, 0.181120, 0.835998, 0.911699, 0.758793, 0.937398, 0.355471,
    0.070501, 0.734815, 0.332647, 0.736103, 0.202031, 0.435297, 0.232261,
    0.282039, 0.482821, 0.251052, 0.280511, 0.393995, 0.329474, 0.561460,
    0.164191, 0.875997, 0.099202, 0.438785, 0.307278, 0.163630, 0.776802,
    0.660393, 0.739244, 0.607367, 0.617446, 0.920364, 0.443365, 0.529145,
    0.679157, 0.380763, 0.884616, 0.749658, 0.115578, 0.217263, 0.485761,
    0.317609, 0.652560, 0.718021, 0.599648, 0.135381, 0.969073, 0.880159,
    0.529376, 0.298547, 0.441619, 0.693567, 0.174544, 0.540821, 0.132351,
    0.481822, 0.704450, 0.909153, 0.142215, 0.443695, 0.516520, 0.759661,
    0.364059, 0.959885, 0.288806, 0.043216, 0.340648, 0.173422, 0.792874,
    0.456226, 0.390685, 0.278634, 0.773834, 0.043245, 0.996656, 0.373483,
    0.178625, 0.965729, 0.253641, 0.708001, 0.264276, 0.695260, 0.401568,
    0.438820, 0.236081, 0.533919, 0.920642, 0.940531, 0.443072, 0.062857,
    0.384226, 0.959592, 0.822518, 0.748285, 0.919477, 0.111325, 0.791501,
    0.260124, 0.284747, 0.584375, 0.716350, 0.675431, 0.863009, 0.490184,
    0.718676, 0.859665, 0.863666, 0.897301, 0.825393, 0.117308, 0.605302,
    0.089669, 0.812568, 0.006870, 0.528489, 0.048649, 0.540788, 0.449131,
    0.989180, 0.983860, 0.511988, 0.373407, 0.943452, 0.334506, 0.121692,
    0.862929, 0.445831, 0.913193, 0.123053, 0.730578, 0.497568, 0.839402,
    0.406009, 0.360577, 0.329586, 0.124685, 0.220241, 0.193253, 0.021986,
    0.045634, 0.310560, 0.627288, 0.135303, 0.123128, 0.634158, 0.663792,
    0.171777, 0.174946, 0.112923, 0.160958, 0.158806, 0.624911, 0.534364,
    0.102259, 0.959418, 0.656056, 0.965187, 0.405249, 0.569249, 0.088240,
    0.135827, 0.066817, 0.927642, 0.541836, 0.427393, 0.257229, 0.666520,
    0.647634, 0.450481, 0.688506, 0.693269, 0.761042, 0.315794, 0.828572,
    0.884170, 0.949952, 0.492364, 0.055947, 0.124898, 0.605288, 0.216905,
    0.283705, 0.230199, 0.751269, 0.385963, 0.189616, 0.407326, 0.351151,
    0.594865, 0.976575, 0.439391, 0.730692, 0.043392, 0.367033, 0.272527,
    0.470785, 0.624261, 0.939048, 0.118419, 0.074743, 0.627554, 0.811688,
    0.835784, 0.943348, 0.640260, 0.719954, 0.893300, 0.132625, 0.775901,
    0.018199, 0.737913, 0.992806, 0.301903, 0.968111, 0.744076, 0.687867,
    0.157728, 0.151401, 0.039017, 0.752593, 0.127976, 0.478408, 0.483284,
    0.171368, 0.845441, 0.755811, 0.642153, 0.469702, 0.694859, 0.760572,
    0.544445, 0.322413, 0.572260, 0.380229, 0.265761, 0.212521, 0.100183,
    0.159062, 0.345146, 0.876084, 0.177261, 0.083058, 0.868891, 0.479164,
    0.051169, 0.612966, 0.167030, 0.208897, 0.764367, 0.206048, 0.961490,
    0.892343, 0.684456, 0.444774, 0.063711, 0.529896, 0.200585, 0.705863,
    0.999598, 0.895444, 0.466435, 0.544043, 0.217857, 0.038696, 0.924272,
    0.483618, 0.251217, 0.024455, 0.642680, 0.596362, 0.900539, 0.819941,
    0.679420, 0.769430, 0.299105, 0.730590, 0.382396, 0.466135, 0.939487,
    0.146763, 0.672183, 0.900977, 0.039106, 0.356638, 0.345750, 0.102817,
    0.886535, 0.546336, 0.808681, 0.886133, 0.441780, 0.275116, 0.430176,
    0.659637, 0.313812, 0.354448, 0.143255, 0.565028, 0.378903, 0.785935,
    0.161391, 0.279443, 0.605876, 0.840811, 0.048873, 0.904980, 0.571401,
    0.431269, 0.371115, 0.510887, 0.578032, 0.043298, 0.411864, 0.617138,
    0.399936, 0.757614, 0.719955, 0.286471, 0.303950, 0.528636, 0.172604,
    0.745730, 0.803752, 0.602780, 0.405367, 0.117564, 0.957228, 0.548622,
    0.682592, 0.336131, 0.334557, 0.843983, 0.615574, 0.940433, 0.684794,
    0.664447, 0.845413, 0.256194, 0.095715, 0.216529, 0.767082, 0.673747,
    0.259827, 0.178946, 0.290885, 0.659763, 0.936560, 0.010840, 0.946234,
    0.240510, 0.539476, 0.118838, 0.986240, 0.343228, 0.721618, 0.391606,
    0.460792, 0.678846, 0.940228, 0.143384, 0.014977, 0.274785, 0.987367,
    0.630551, 0.215218, 0.672161, 0.294998, 0.060631, 0.928355, 0.390713,
    0.277160, 0.695436, 0.064460, 0.536987, 0.874382, 0.355345, 0.196751,
    0.810942, 0.366185, 0.142985, 0.051452, 0.905661, 0.261823, 0.037691,
    0.248889, 0.983441, 0.429297, 0.709681, 0.662286, 0.369525, 0.853066,
    0.677263, 0.644310, 0.840433, 0.307814, 0.859528, 0.512593, 0.602812,
    0.920160, 0.440948, 0.993525, 0.197320, 0.136384, 0.057984, 0.734307,
    0.010766, 0.413329, 0.931058, 0.821707, 0.779514, 0.074043, 0.873159,
    0.685175, 0.335865, 0.910850, 0.934065, 0.319306, 0.340147, 0.643746,
    0.981592, 0.709673, 0.496812, 0.658856, 0.353983, 0.337245, 0.966670,
    0.213511, 0.849838, 0.569482, 0.133671, 0.290786, 0.563007, 0.330991,
    0.427170, 0.620991, 0.065299, 0.437936, 0.034320, 0.996356, 0.259643,
    0.813834, 0.070399, 0.132802, 0.499009, 0.406265, 0.043652, 0.433074,
    0.725570, 0.383800, 0.076820, 0.707163, 0.093473, 0.573632, 0.366018,
    0.447456, 0.910877, 0.332688, 0.660967, 0.760714, 0.902170, 0.794638,
    0.051500, 0.465177, 0.125630, 0.478670, 0.086168, 0.190928, 0.916605,
    0.120488, 0.187285, 0.176248, 0.934322, 0.257684, 0.309050, 0.433331,
    0.663949, 0.352703, 0.866405, 0.389519, 0.736502, 0.943226, 0.096682,
    0.829975, 0.516858, 0.462700, 0.277430, 0.427734, 0.795388, 0.938398,
    0.188449, 0.697558, 0.733036, 0.239948, 0.162735, 0.858666, 0.718618,
    0.248903, 0.049594, 0.635223, 0.369391, 0.236879, 0.811472, 0.303713,
    0.494563, 0.120522, 0.737044, 0.158511, 0.473225, 0.603450, 0.548030,
    0.209727, 0.546675, 0.644712, 0.039702, 0.063533, 0.107412, 0.317132,
    0.491267, 0.902800, 0.255530, 0.679716, 0.600359, 0.988566, 0.919664,
    0.763094, 0.847232, 0.638283, 0.011997, 0.896825, 0.273506, 0.381388,
    0.133704, 0.084978, 0.685101, 0.628267, 0.205500, 0.422145, 0.786778,
    0.678725, 0.025595, 0.334808, 0.888452, 0.572271, 0.979520, 0.928154,
    0.635804, 0.086932, 0.245286, 0.127071, 0.989732, 0.500816, 0.806787,
    0.590091, 0.489382, 0.726451, 0.353185, 0.336614, 0.364734, 0.365182,
    0.233439, 0.638240, 0.746570, 0.367143, 0.723218, 0.431671, 0.995410,
    0.928718, 0.853816, 0.782188, 0.607442, 0.879411, 0.116995, 0.495894,
    0.451682, 0.096515, 0.424048, 0.087485, 0.183447, 0.669334, 0.214556,
    0.173179, 0.170151, 0.021343, 0.763269, 0.659533, 0.747794, 0.116454,
    0.996147, 0.112528, 0.481635, 0.229586, 0.750768, 0.228205, 0.596730,
    0.473985, 0.659876, 0.592139, 0.402703, 0.513692, 0.374327, 0.010145,
    0.393103, 0.491322, 0.506039, 0.844785, 0.587837, 0.930088, 0.932270,
    0.771284, 0.599422, 0.146826, 0.944463, 0.769573, 0.168169, 0.707732,
    0.429106, 0.915964, 0.824186, 0.425253, 0.028492, 0.305821, 0.654839,
    0.779259, 0.534026, 0.251569, 0.253245, 0.193901, 0.843708, 0.655947,
    0.707593, 0.218035, 0.666093, 0.100696, 0.709357, 0.172132, 0.945481,
    0.297195, 0.102220, 0.877751, 0.068479, 0.701642, 0.024577, 0.012941,
    0.471215, 0.192747, 0.720673, 0.900321, 0.108710, 0.544859, 0.325574,
    0.137202, 0.850679, 0.980413, 0.916462, 0.384705, 0.231982, 0.169706,
    0.578607, 0.075690, 0.825654, 0.286200, 0.293725, 0.491746, 0.386896,
    0.003083, 0.663878, 0.332377, 0.300278, 0.766098, 0.210128, 0.368756,
    0.467740, 0.234705, 0.381697, 0.938955, 0.427451, 0.102370, 0.839275,
    0.536162, 0.647229, 0.164849, 0.673364, 0.497908, 0.145262, 0.589825,
    0.882613, 0.377244, 0.759532, 0.461220, 0.452934, 0.585185, 0.747420,
    0.746660, 0.076932, 0.134316, 0.749743, 0.740810, 0.466692, 0.050020,
    0.506908, 0.676820, 0.418776, 0.974648, 0.911525, 0.800474, 0.913602,
    0.338976, 0.902844, 0.752878, 0.875138, 0.550072, 0.917727, 0.548502,
    0.047981, 0.062989, 0.138327, 0.930594, 0.440233, 0.897859, 0.391814,
    0.893168, 0.483044, 0.139234, 0.639828, 0.559975, 0.273549, 0.389570,
    0.300785, 0.740242, 0.439590, 0.807693, 0.417062, 0.858367, 0.782341,
    0.328586, 0.658840, 0.695943, 0.667562, 0.561684, 0.448821, 0.542700,
    0.111756, 0.366548, 0.091202, 0.159737, 0.429537, 0.229529, 0.090331,
    0.869770, 0.127388, 0.482145, 0.762938, 0.610432, 0.621379, 0.402765,
    0.170407, 0.894928, 0.792336, 0.471192, 0.635170, 0.231926, 0.278886,
    0.052232, 0.090293, 0.061226, 0.380818, 0.749133, 0.757170, 0.048380,
    0.310817, 0.205990, 0.591080, 0.422573, 0.572538, 0.682282, 0.582310,
    0.002075, 0.911812, 0.672641, 0.871845, 0.039199, 0.154786, 0.634783,
    0.649631, 0.776165, 0.037548, 0.820038, 0.671093, 0.829884, 0.291231,
    0.306263, 0.061810, 0.570116, 0.358495, 0.152103, 0.631343, 0.739313,
    0.901236, 0.388512, 0.787693, 0.212053, 0.594503, 0.378773, 0.634626,
    0.167040, 0.061056, 0.216937, 0.169115, 0.972867, 0.889578, 0.040960,
    0.012067, 0.044364, 0.675743, 0.661698, 0.820529, 0.713291, 0.481736,
    0.491623, 0.543175, 0.772966, 0.797886, 0.604985, 0.343083, 0.156380,
    0.757088, 0.974425, 0.895693, 0.658324, 0.362938, 0.683386, 0.870376,
    0.957440, 0.062159, 0.505002, 0.124481, 0.123215, 0.721939, 0.293596,
    0.096082, 0.611517, 0.334556, 0.108149, 0.655881, 0.010299, 0.769846,
    0.476411, 0.723590, 0.251582, 0.968033, 0.266765, 0.024548, 0.765919,
    0.871750, 0.367631, 0.922299, 0.628838, 0.342056, 0.817992, 0.287162,
    0.704994, 0.501378, 0.157538, 0.662434, 0.563537, 0.662541, 0.786915,
    0.686752, 0.384480, 0.080511, 0.782834, 0.995997, 0.415067, 0.890983,
    0.651878, 0.425365, 0.660829, 0.128289, 0.148956, 0.912411, 0.096322,
    0.415721, 0.936959, 0.862241, 0.287471, 0.304590, 0.784540, 0.916309,
    0.646646, 0.602533, 0.203471, 0.351640, 0.103911, 0.361009, 0.014074,
    0.667448, 0.023550, 0.800989, 0.354200, 0.408030, 0.881500, 0.137034,
    0.404026, 0.296566, 0.028017, 0.055904, 0.721932, 0.688846, 0.184193,
    0.870887, 0.601257, 0.280515, 0.286608, 0.538216, 0.142755, 0.574079,
    0.842806, 0.927296, 0.490388, 0.489452, 0.529828, 0.693859, 0.841092,
    0.633739, 0.054869, 0.855167, 0.301187, 0.078419, 0.656156, 0.655388,
    0.486448, 0.537656, 0.792422, 0.890475, 0.834222, 0.820439, 0.946379,
    0.556153, 0.509285, 0.130571, 0.427041, 0.110542, 0.411086, 0.713648,
    0.648758, 0.553842, 0.287727, 0.491563, 0.481137, 0.778116, 0.981015,
    0.010966, 0.471975, 0.822107, 0.644705, 0.526844, 0.677274, 0.945892,
    0.605263, 0.333430, 0.601280, 0.091711, 0.871086, 0.393702, 0.982186,
    0.705307, 0.214141, 0.928564, 0.261461, 0.723426, 0.059136, 0.688501,
    0.833968, 0.470222, 0.402150, 0.482725, 0.024063, 0.689877, 0.974289,
    0.505201, 0.467993, 0.955304, 0.516166, 0.939968, 0.777411, 0.160871,
    0.466812, 0.454685, 0.106763, 0.072075, 0.788115, 0.708043, 0.163786,
    0.659201, 0.101744, 0.145971, 0.364508, 0.315885, 0.074536, 0.625969,
    0.039311, 0.133672, 0.314471, 0.873279, 0.603893, 0.716620, 0.356004,
    0.627957, 0.406498, 0.330292, 0.133157, 0.874490, 0.285596, 0.649324,
    0.814458, 0.063007, 0.810195, 0.281270, 0.517693, 0.916958, 0.353345,
    0.305808, 0.625000, 0.517131, 0.965009, 0.726745, 0.663102, 0.329518,
    0.042630, 0.737638, 0.955487, 0.081940, 0.871310, 0.269957, 0.955219,
    0.475203, 0.986578, 0.311223, 0.103160, 0.393075, 0.641515, 0.236317,
    0.267566, 0.927112, 0.885641, 0.082024, 0.990119, 0.695835, 0.363295,
    0.507812, 0.612793, 0.716640, 0.813620, 0.237793, 0.233770, 0.778629,
    0.964538, 0.896872, 0.108147, 0.007167, 0.634510, 0.063633, 0.089108,
    0.505820, 0.333591, 0.044327, 0.981023, 0.320168, 0.355550, 0.084182,
    0.713244, 0.997065, 0.320499, 0.980810, 0.924177, 0.206140, 0.062834,
    0.914296, 0.901975, 0.426129, 0.422107, 0.514768, 0.142768, 0.235727,
    0.752561, 0.376539, 0.014356, 0.717099, 0.273411, 0.122502, 0.724266,
    0.907921, 0.186136, 0.813374, 0.413741, 0.519726, 0.857701, 0.394764,
    0.839895, 0.213251, 0.478946, 0.553139, 0.210317, 0.799446, 0.533948,
    0.134493, 0.005586, 0.596782, 0.048789, 0.907561, 0.022911, 0.470896,
    0.422329, 0.165679, 0.706623, 0.174890, 0.542218, 0.720979, 0.891989,
    0.815629, 0.843481, 0.616255, 0.723551, 0.029617, 0.429630, 0.137292,
    0.549343, 0.287331, 0.532056, 0.389238, 0.500583, 0.011002, 0.942377,
    0.710899, 0.810448, 0.476326, 0.845392, 0.816033, 0.073108, 0.894181,
    0.723594, 0.096019, 0.365077, 0.145923, 0.261699, 0.071700, 0.320813,
    0.803917, 0.792679, 0.212802, 0.619546, 0.636160, 0.829057, 0.343096,
    0.665777, 0.258687, 0.480388, 0.215121, 0.546018, 0.012444, 0.604359,
    0.046601, 0.023446, 0.546736, 0.757500, 0.833893, 0.023062, 0.602892,
    0.649927, 0.096170, 0.497074, 0.373521, 0.192189, 0.862151, 0.519444,
    0.453887, 0.933851, 0.840257, 0.257804, 0.726531, 0.053058, 0.877350,
    0.362691, 0.882115, 0.220446, 0.028468, 0.140802, 0.700834, 0.243589,
    0.686821, 0.713278, 0.847948, 0.733421, 0.736723, 0.394684, 0.490921,
    0.570617, 0.417746, 0.093813, 0.220543, 0.513916, 0.590887, 0.594064,
    0.706105, 0.453038, 0.113508, 0.159992, 0.386889, 0.953765, 0.417796,
    0.113420, 0.006823, 0.295146, 0.476111, 0.888938, 0.515592, 0.504579,
    0.029741, 0.216426, 0.748168, 0.716561, 0.929703, 0.596117, 0.449982,
    0.666427, 0.990801, 0.940903, 0.237043, 0.408547, 0.034717, 0.457587,
    0.922463, 0.625603, 0.051651, 0.628568, 0.078641, 0.165159, 0.788560,
    0.465530, 0.118923, 0.206356, 0.578950, 0.125746, 0.501502, 0.055060,
    0.014685, 0.017094, 0.559640, 0.044425, 0.233519, 0.307808, 0.760986,
    0.163223, 0.903925, 0.210969, 0.829650, 0.894726, 0.151872, 0.066693,
    0.303273, 0.186589, 0.524279, 0.225736, 0.812192, 0.575930, 0.854304,
    0.890833, 0.741089, 0.642864, 0.356363, 0.860012, 0.849220, 0.935313,
    0.985758, 0.350722, 0.990373, 0.000443, 0.367815, 0.550013, 0.044868,
    0.601335, 0.857820, 0.805855, 0.764557, 0.761745, 0.016823, 0.594207,
    0.656471, 0.168696, 0.660900, 0.959744, 0.355284, 0.185179, 0.185480,
    0.167477, 0.761110, 0.039784, 0.058310, 0.502199, 0.682648, 0.414673,
    0.362211, 0.531868, 0.349985, 0.347969, 0.882589, 0.340358, 0.348412,
    0.250404, 0.890371, 0.393280, 0.851739, 0.748191, 0.199135, 0.616297,
    0.509936, 0.215958, 0.210504, 0.166407, 0.384654, 0.871404, 0.126151,
    0.739938, 0.056583, 0.311631, 0.907415, 0.817693, 0.351415, 0.965724,
    0.319891, 0.034062, 0.380397, 0.682102, 0.565930, 0.730382, 0.030072,
    0.448519, 0.070741, 0.378484, 0.698924, 0.961112, 0.771764, 0.550663,
    0.709303, 0.970899, 0.166959, 0.219239, 0.186857, 0.377463, 0.385647,
    0.571511, 0.248867, 0.511798, 0.311449, 0.305450, 0.823429, 0.218864,
    0.123142, 0.174844, 0.184588, 0.443034, 0.208906, 0.564986, 0.125136,
    0.774836, 0.295368, 0.155207, 0.223355, 0.366109, 0.533691, 0.922279,
    0.327221, 0.305455, 0.472942, 0.036524, 0.276354, 0.639901, 0.255763,
    0.463211, 0.017364, 0.641410, 0.034722, 0.266231, 0.153207, 0.346171,
    0.571680, 0.976636, 0.565036, 0.694822, 0.151480, 0.749624, 0.137856,
    0.360386, 0.314610, 0.262992, 0.135222, 0.609978, 0.418200, 0.358578,
    0.976087, 0.951891, 0.280856, 0.303307, 0.257346, 0.753798, 0.339831,
    0.533700, 0.393699, 0.595594, 0.996911, 0.411063, 0.237003, 0.031634,
    0.677294, 0.390211, 0.377805, 0.248974, 0.366847, 0.942841, 0.943796,
    0.518327, 0.692465, 0.081653, 0.878713, 0.007074, 0.344645, 0.013936,
    0.617052, 0.762845, 0.372513, 0.593138, 0.714736, 0.653370, 0.896446,
    0.972082, 0.407168, 0.236276, 0.505782, 0.800867, 0.831870, 0.502693,
    0.211930, 0.068873, 0.534327, 0.889224, 0.459084, 0.912132, 0.138197,
    0.825931, 0.854972, 0.081994, 0.344259, 0.547437, 0.163646, 0.222972,
    0.554511, 0.508291, 0.236908, 0.171563, 0.271135, 0.609421, 0.764701,
    0.985871, 0.262790, 0.661147, 0.957953, 0.669958, 0.897423, 0.463734,
    0.470825, 0.729293, 0.966427, 0.682755, 0.798166, 0.500754, 0.571978,
    0.257251, 0.412886, 0.710176, 0.083182, 0.267858, 0.792169, 0.427441,
    0.815295, 0.955815, 0.650413, 0.369805, 0.464106, 0.887320, 0.541368,
    0.735242, 0.496741, 0.306069, 0.721113, 0.759531, 0.967216, 0.679065,
    0.429489, 0.864639, 0.142799, 0.900314, 0.593932, 0.109227, 0.583069,
    0.392098, 0.609981, 0.155047, 0.649349, 0.022867, 0.865222, 0.732531,
    0.290725, 0.657392, 0.159972, 0.106019, 0.613207, 0.810384, 0.475824,
    0.077313, 0.697704, 0.017192, 0.812555};

static float golden_endtoend_output[] = {
    -1.881211, -0.028385, -3.585066, 1.939770, -3.461155, 1.280415, -4.408978,
    0.608663, -2.704937, 1.859742, -5.777429, 2.691839, -1.049012, 1.640870,
    -4.856245, 1.604236, 0.992707, 0.422858, -4.307465, 1.887332, -0.884831,
    -0.154277, -2.634801, 0.586827, -1.849960, 1.399608, -4.531559, 1.943591,
    0.271676, -2.893054, -2.066826, 0.235467, -1.248263, -1.164534, -2.640174,
    -0.112878, -4.386484, 1.253024, -4.135623, 1.068984, -0.043579, -0.832957,
    -3.257258, -0.514396, -1.651174, 0.638630, -4.364372, 1.548441, -0.289455,
    0.539845, -4.097627, 0.635001, -0.465071, -0.927701, -2.481498, 0.356616,
    -2.355012, 0.728806, -3.340283, 1.609038, -4.786268, -0.532272, -1.886150,
    0.254797, 0.746620, -1.657134, -3.264265, 0.525551, -1.756837, 0.845446,
    -5.572190, 1.715797, -2.856942, 3.394245, -5.803662, 2.281806, -3.014739,
    2.616136, -4.728482, 1.659984, -2.106307, 2.711709, -6.173832, 1.352869,
    -0.038035, 0.107619, -4.279774, 2.341930, -0.980413, -0.119538, -4.049717,
    1.172128, -3.477744, 2.602274, -6.231380, 2.537300, -0.862214, 0.568722,
    -3.858362, 0.197867, -1.725885, 3.687312, -7.067363, 2.403544, -0.944963,
    0.235639, -3.250094, 0.659117, -1.459576, 0.426128, -3.637207, 1.030386,
    -4.224351, 3.516220, -6.053367, 0.993473, -2.182416, -0.762625, -1.884405,
    -0.113736, -2.572602, 0.329290, -1.913233, 0.517418, -0.019757, 0.203176,
    -3.715881, 0.482136, -1.912823, 1.357907, -5.473043, 1.714658, -3.177160,
    0.089285, -3.127669, 1.268076, 0.772498, -1.622712, -3.850314, 0.436124,
    -1.495983, 3.439982, -7.623405, 1.726721, -0.423979, 0.180201, -2.902406,
    0.986457, -1.845638, 0.460903, -5.359343, -1.133931, -1.074456, 0.717304,
    -3.519856, 1.012126, -0.562301, 1.881967, -6.716627, 2.525036, 0.945480,
    0.337081, -5.210562, 2.572035, -0.943370, 0.442026, -2.666313, 0.411296,
    0.002787, -0.000735, -2.498933, 0.771719, -3.568153, 3.833721, -6.617026,
    2.813922, -0.573970, 1.025208, -3.909923, 1.722648, -1.406849, 0.719783,
    -5.207438, 1.819442, -0.530895, -0.010887, -2.939614, 0.971225, -1.660297,
    1.345243, -4.454571, 2.244876, -2.021213, 1.756090, -4.880947, 0.364597,
    -2.380270, 2.763117, -5.613013, 2.137534, 0.289101, -2.279400, -3.365582,
    0.170028, -1.142254, -0.709604, -3.656223, 1.804870, -0.854690, 0.592102,
    -5.010415, 2.462687, -1.474710, 0.566002, -3.621819, -0.391946, -0.423524,
    -0.631428, -3.513310, 0.962825, -1.480262, 0.319791, -3.610137, 1.842339,
    -0.250073, 1.182022, -6.249267, 1.604172, 1.153759, -0.734054, -4.620415,
    -0.030858, 0.050911, 1.524406, -4.724010, 1.451846, -3.277104, 2.414182,
    -4.605285, 1.846092, -1.503047, -0.618200, -2.746546, -0.459332, -0.980326,
    -1.199977, -2.043865, -0.165793, -2.214698, 3.108281, -7.127830, -0.123065,
    1.244948, -3.039923, -4.660061, -0.225957, -0.307210, -1.513205, -2.456005,
    0.840048, -0.741445, 2.328635, -6.015267, 2.723240, -1.381171, -0.728878,
    -5.114925, -0.362034, -0.574923, 0.518080, -3.892457, 1.798948, 0.435119,
    -0.371696, -2.807571, 1.302864, -2.063052, 1.036388, -4.232038, 1.397059,
    -1.615668, -1.511019, -3.095508, 1.290955, -3.428723, 2.000287, -4.196487,
    1.566983, 0.196957, 0.224343, -4.926359, -0.691975, -0.214941, 1.546821,
    -5.384868, 2.290820, -1.878865, 0.493692, -4.129823, 2.112036, 0.516558,
    -2.553077, -2.717338, 0.017146, -2.016057, 1.628995, -4.240602, 1.189533,
    -5.460220, 1.254738, -4.214903, 0.755659, -2.893235, 2.937762, -6.169453,
    2.035456, -5.613212, -0.122254, -1.973646, -0.060619, -2.119598, 1.413512,
    -4.938738, 1.890244, 0.544169, -2.062413, -3.329637, -0.062515, -1.855805,
    -0.791297, -2.570353, 0.607615, 0.305812, 0.338930, -4.150270, 2.274937,
    0.042653, 0.133825, -3.538155, 1.523639, -3.173690, -1.496599, -2.414655,
    0.464687, -1.448998, -0.368907, -3.520129, 0.203382, -2.443626, 1.266233,
    -3.393848, 0.605911, -0.015353, 1.402006, -4.441003, 1.419281, 0.603587,
    0.434146, -4.966566, 2.171872, -0.688264, -0.009981, -4.461103, 1.538354,
    -5.029816, -0.264424, -1.713510, -0.315258, -1.891606, 0.252074, -2.419428,
    0.043970, -1.291143, 2.048704, -4.590105, 0.524734, -1.889576, 0.134836,
    -3.462745, 1.390663, -0.112773, 0.402735, -4.203784, 1.381043, -1.201634,
    -1.968277, -1.425637, -0.181725, -1.250742, -2.102041, -3.925464, -1.256797,
    -3.701354, -1.754610, -1.917231, -1.455910, -1.838006, 2.041781, -5.666212,
    2.752957, -2.659553, 2.553637, -4.872212, 1.443437, -2.081846, 3.311263,
    -5.912457, 1.871049, 0.196148, -0.307044, -4.024967, 2.149149, 0.361809,
    0.620415, -5.939984, 0.180672, -1.209180, -0.269122, -3.240285, 1.460315,
    -1.040803, 1.125700, -6.060366, 0.887767, -3.214111, 1.314368, -3.026808,
    1.023640, -3.815175, 1.795642, -4.355603, 1.064454, -0.046472, 0.618463,
    -5.941646, 2.861891, -2.852155, -0.990457, -2.624445, 1.794494, -1.176747,
    -0.358159, -3.206776, 1.138721, -2.819523, -1.825522, -1.450902, -0.187312,
    -0.808727, 0.636872, -4.120567, 1.192623, 0.810731, -1.768519, -3.699450,
    1.527116, -2.772720, 3.012835, -5.912736, 1.599365, -4.696381, 2.234591,
    -4.139552, 1.061768, -1.880089, 3.596274, -7.006379, 2.382152, -3.158115,
    3.844430, -7.044156, 2.307596, -2.473970, 1.312644, -5.467269, 0.197154,
    -1.530040, 1.762275, -5.550757, 0.630276, -3.048947, 1.043777, -3.096658,
    1.345893, -1.329494, 2.065748, -4.711032, 2.227600, -0.413321, -0.032428,
    -4.599650, 1.668734, -4.351490, -0.200022, -2.359903, 0.021997, 0.116028,
    1.159718, -5.093972, -0.142951, -2.409895, 0.906133, -2.728812, 0.809932,
    -2.597363, 0.494130, -2.357861, 0.369825, -2.165235, 1.148522, -3.130562,
    0.759034, 0.646335, -1.463660, -3.508299, 1.059679, -1.485465, 1.007319,
    -4.340716, 1.789864, -1.590654, 1.612324, -4.452007, 2.389805, -5.200148,
    -1.068398, -1.306923, -0.472408, -0.392165, -0.524996, -2.933478, 1.518430,
    -1.287781, 0.113422, -3.020525, 1.338359, -0.105982, 0.936014, -4.132197,
    1.836807, -0.616589, -1.029716, -3.271347, 0.284889, -2.653359, 2.135829,
    -4.643613, 1.627981, 0.287733, -2.017263, -2.776574, 1.184792, 1.004161,
    -1.483019, -4.339290, -0.787322, 0.582420, 1.137839, -5.673941, -0.001862,
    -1.219142, 0.532561, -4.457245, 1.826807, -3.343291, 3.034610, -6.179855,
    2.235917, -4.369989, 4.018128, -6.632714, 0.926585, -0.485469, 0.536073,
    -4.179557, 1.489637, -0.521762, 1.636089, -6.137912, 1.500867, -4.086009,
    1.961372, -3.688977, 1.358220, -1.544034, 1.763837, -4.357567, 1.852201,
    -2.018725, 1.046264, -6.211127, 1.609419, -0.118441, 1.602284, -6.242423,
    1.518578, -0.604078, 1.106613, -5.393445, 2.595629, 0.142712, -1.903953,
    -2.821177, 0.032758, -0.009152, 0.184628, -4.227636, 2.046843, -2.240138,
    1.256176, -5.108516, -0.308447, -2.998571, 4.657396, -7.582112, 2.510951,
    -3.535784, 1.704560, -5.068484, 1.318466, -3.058265, 3.073172, -6.998089,
    3.178849, -2.420286, 2.277806, -4.999528, 1.423890, -1.672914, 0.447460,
    -4.088940, 1.351087, -1.051546, -0.417955, -4.042147, 1.604102, -1.700931,
    2.796663, -6.497579, 2.857974, -0.240828, 0.858001, -5.778933, 2.778508,
    -0.406211, 1.300766, -5.073671, 2.089362, -0.201673, 1.588396, -6.000150,
    2.185055, -2.332125, 0.768216, -2.609184, 0.327277, -3.358943, -1.020736,
    -2.389984, 0.315512, -0.561905, 1.948740, -6.408485, 2.231985, -0.603652,
    0.661829, -5.070386, -1.063058, -0.624796, 1.375772, -4.379606, 1.929358,
    -1.047263, 0.739100, -5.217857, 2.127625, -5.025338, 0.650344, -2.068460,
    0.076936, -0.457505, -1.050984, -1.917765, 1.150908, 0.782625, 0.855595,
    -5.321719, 0.787209, -0.460232, 1.106736, -5.552326, 2.801043, -0.360217,
    -0.434432, -4.273378, 0.967556, -0.972652, 0.874811, -5.429918, -0.331039,
    0.115477, 0.111883, -5.418786, 1.240546, -1.842794, 0.505880, -3.676064,
    -0.682369, 1.858984, -0.742566, -5.784060, 0.673239, -1.280398, 0.280842,
    -4.848077, 2.214860, -0.785100, -0.588488, -2.438206, 0.786651, -1.568752,
    1.935400, -6.320256, 2.125338, -1.476457, -1.651941, -2.695734, 0.007338,
    -3.280860, 2.310385, -5.319578, 1.890123, -0.775723, 0.630606, -4.321582,
    1.085521, -1.847371, 1.188521, -4.596577, 2.056443, -2.340172, -0.108501,
    -3.156392, 0.933279, -0.495331, 0.122405, -5.171133, 1.763245, -0.796913,
    2.310487, -7.247197, 2.401678, -1.908860, 0.043798, -2.393796, 0.573806,
    -0.608531, 0.154710, -4.669001, 0.750680, 0.468380, 0.392591, -4.755001,
    2.615217, -1.957774, 1.153513, -4.530099, 1.124362, -3.569415, 1.697154,
    -3.536335, 0.910758, -2.976264, 1.833129, -4.287203, -0.547050, -2.409768,
    0.061585, -1.324116, 0.268497, -2.962222, -1.524245, -2.063413, 0.442058,
    -4.292337, 3.538863, -6.699603, 1.718664, -2.290363, 1.994596, -6.245037,
    -0.433084, -0.367059, 1.020297, -4.940721, 2.902264, -0.577056, -0.709887,
    -5.001413, -0.268316, -1.112048, -1.083307, -1.753492, 0.209973, 0.139540,
    0.917602, -5.232745, 2.538467, -2.139234, -0.187388, -1.837249, -0.478582,
    -0.731653, -0.481550, -2.531261, 1.044770, 0.707750, 0.279971, -3.221119,
    1.552074, -2.373144, 0.859518, -3.665156, 1.620278, -1.440871, -0.525581,
    -2.758271, 1.491873, -2.302013, 1.119935, -5.257080, 2.627170, -3.174739,
    1.363282, -4.831639, 1.101076, -4.337008, 2.689639, -5.165915, 1.069201,
    -1.882078, -0.120370, -2.287967, 1.147619, -1.403616, 1.077150, -5.084296,
    1.658236, -0.919642, 0.487423, -3.001075, 0.741268, 0.107300, 0.943556,
    -3.544311, 1.000239, -1.627171, 2.871253, -5.179172, 1.429893, -0.826040,
    0.188670, -4.499894, 1.013447, -2.101299, 0.317516, -3.452141, -0.833776,
    -1.362144, 1.272437, -4.449355, 1.613591, -2.039873, 2.613175, -6.229640,
    1.659790, -1.595520, -0.237462, -2.744997, 0.337841, 0.148981, -1.703771,
    -2.388023, 1.276469, 1.058508, -0.401642, -4.680769, 0.861881, -1.336381,
    1.153080, -2.834378, 0.721075, 0.900115, 1.360511, -5.573611, 0.949182,
    -2.970844, 2.017563, -5.186108, -0.201038, -1.192824, 0.610142, -4.450919,
    -0.897114, -1.812093, 0.422310, -5.245487, 0.256549, 0.320275, -2.324150,
    -2.967040, -0.260536, -0.721467, 0.454148, -5.058031, 0.526370, -0.895656,
    0.732240, -3.327363, 1.353953, -1.277912, -0.483171, -1.926713, 0.065044,
    -2.167506, -0.196606, -1.923437, 0.604962, -2.088319, 1.406834, -5.227296,
    2.247351, -4.421744, 1.729791, -5.007922, 1.264769, -0.897019, 0.922902,
    -3.887108, 2.087432, -1.310226, -0.101938, -3.359082, -0.079662, -0.514988,
    -0.963179, -4.038209, 2.223278, -0.590083, -2.310458, -1.748338, 0.363406,
    -0.540731, -0.885913, -4.179595, 2.216781, -3.044339, -0.447100, -2.446098,
    0.931101, -1.676190, 2.096175, -4.980755, 2.262151, -1.095047, 1.897516,
    -5.996138, 2.191038, 0.297128, -0.780974, -2.884299, 1.195408, -0.521065,
    -1.955837, -3.091064, -0.404183, -1.961519, 4.076096, -7.521851, 2.242064,
    -1.988043, 0.303300, -2.422585, 0.322230, -3.377634, 3.499955, -7.084434,
    2.375587, -0.718851, 2.150076, -5.412241, 2.374280, -2.006088, 2.229828,
    -5.848188, 2.543077, -2.171042, 2.096026, -5.300007, 0.141405, -1.187745,
    0.105340, -4.003816, 1.034281, -3.980804, 1.856709, -5.103042, 0.623737,
    -2.080307, 0.896140, -3.104050, 0.983158, -0.424898, -1.154270, -3.805728,
    1.978917, -1.314387, 1.235096, -3.148906, 1.113173, 0.111713, 2.055213,
    -7.565283, 2.100342};

const std::initializer_list<float> biases = {
    0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
    -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178,
    0.37197268, 0.61957061, 0.3956964, -0.37609905};

const std::initializer_list<float> recurrent_weights = {
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1};

class BidirectionalRNNOpModel : public SingleOpModel {
 public:
  BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
                          int bw_units, int input_size, int aux_input_size,
                          AuxInputMode aux_input_mode, bool time_major,
                          bool merge_outputs, bool quantize_weights = false,
                          bool asymmetric_quantize_weights = false)
      : batches_(batches),
        sequence_len_(sequence_len),
        fw_units_(fw_units),
        bw_units_(bw_units),
        input_size_(input_size),
        aux_input_size_(aux_input_size),
        quantize_weights_(quantize_weights) {
    const TensorType tensor_type =
        quantize_weights ? TensorType_UINT8 : TensorType_FLOAT32;
    input_ = AddInput(TensorType_FLOAT32);
    fw_weights_ = AddInput(tensor_type);
    fw_recurrent_weights_ = AddInput(tensor_type);
    fw_bias_ = AddInput(TensorType_FLOAT32);
    fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
    bw_weights_ = AddInput(tensor_type);
    bw_recurrent_weights_ = AddInput(tensor_type);
    bw_bias_ = AddInput(TensorType_FLOAT32);
    bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);

    const auto input_shape =
        (time_major) ? std::vector<int>({sequence_len_, batches_, input_size_})
                     : std::vector<int>({batches_, sequence_len_, input_size_});

    std::vector<int> aux_input_shape = {0};
    std::vector<int> aux_fw_weights_shape = {0};
    std::vector<int> aux_bw_weights_shape = {0};
    if (aux_input_mode != AuxInputMode::kNoAuxInput) {
      aux_input_ = AddInput(TensorType_FLOAT32);
      aux_input_shape =
          (time_major)
              ? std::vector<int>({sequence_len_, batches_, aux_input_size_})
              : std::vector<int>({batches_, sequence_len_, aux_input_size_});
    } else {
      aux_input_ = AddNullInput();
    }

    if (aux_input_mode == AuxInputMode::kCrossLinking) {
      aux_fw_weights_ = AddInput(tensor_type);
      aux_bw_weights_ = AddInput(tensor_type);

      aux_fw_weights_shape = {fw_units, aux_input_size_};
      aux_bw_weights_shape = {bw_units, aux_input_size_};
    } else {
      aux_fw_weights_ = AddNullInput();
      aux_bw_weights_ = AddNullInput();
    }

    fw_output_ = AddOutput(TensorType_FLOAT32);
    if (!merge_outputs) {
      bw_output_ = AddOutput(TensorType_FLOAT32);
    }

    SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
                 BuiltinOptions_BidirectionalSequenceRNNOptions,
                 CreateBidirectionalSequenceRNNOptions(
                     builder_, time_major, ActivationFunctionType_RELU,
                     merge_outputs, asymmetric_quantize_weights)
                     .Union());

    BuildInterpreter({
        input_shape,               // input
        {fw_units_, input_size_},  // fw_weights
        {fw_units_, fw_units_},    // fw_recurrent_weights
        {fw_units_},               // fw_bias
        {batches_, fw_units_},     // fw_hidden_state
        {bw_units_, input_size_},  // bw_weights
        {bw_units_, bw_units_},    // bw_recurrent_weights
        {bw_units_},               // bw_bias
        {batches_, bw_units_},     // bw_hidden_state
        aux_input_shape,           // aux_input
        aux_fw_weights_shape,      // aux_fw_weights
        aux_bw_weights_shape,      // aux_bw_weights
    });
  }

  void SetFwBias(std::initializer_list<float> f) {
    PopulateTensor(fw_bias_, f);
  }

  void SetBwBias(std::initializer_list<float> f) {
    PopulateTensor(bw_bias_, f);
  }

  void SetFwWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(fw_weights_, f);
    } else {
      PopulateTensor(fw_weights_, f);
    }
  }

  void SetBwWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(bw_weights_, f);
    } else {
      PopulateTensor(bw_weights_, f);
    }
  }

  void SetFwRecurrentWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(fw_recurrent_weights_, f);
    } else {
      PopulateTensor(fw_recurrent_weights_, f);
    }
  }

  void SetBwRecurrentWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(bw_recurrent_weights_, f);
    } else {
      PopulateTensor(bw_recurrent_weights_, f);
    }
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  void SetAuxInput(int offset, float* begin, float* end) {
    PopulateTensor(aux_input_, offset, begin, end);
  }

  void SetAuxFwWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(aux_fw_weights_, f);
    } else {
      PopulateTensor(aux_fw_weights_, f);
    }
  }

  void SetAuxBwWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(aux_bw_weights_, f);
    } else {
      PopulateTensor(aux_bw_weights_, f);
    }
  }

  std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
  std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }

  int input_size() { return input_size_; }
  int aux_input_size() { return aux_input_size_; }
  int num_fw_units() { return fw_units_; }
  int num_bw_units() { return bw_units_; }
  int num_batches() { return batches_; }
  int sequence_len() { return sequence_len_; }

 private:
  int input_;
  int fw_weights_;
  int fw_recurrent_weights_;
  int fw_bias_;
  int fw_hidden_state_;
  int fw_output_;
  int bw_weights_;
  int bw_recurrent_weights_;
  int bw_bias_;
  int bw_hidden_state_;
  int bw_output_;
  int aux_input_;
  int aux_fw_weights_;
  int aux_bw_weights_;

  int batches_;
  int sequence_len_;
  int fw_units_;
  int bw_units_;
  int input_size_;
  int aux_input_size_;
  bool quantize_weights_;
};

// Declare BidirectionalRNNOpTest as a parameterized test.
class BidirectionalRNNOpTest
    : public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};

INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, BidirectionalRNNOpTest,
                         ::testing::Combine(
                             /*quantize_weights*/ ::testing::Bool(),
                             /*asymmetric_quantize_inputs*/ ::testing::Bool()));

// TODO(mirkov): add another test which directly compares to TF once TOCO
// supports the conversion from dynamic_rnn with BasicRNNCell.
TEST_P(BidirectionalRNNOpTest, BlackBoxTest) {
  auto params = GetParam();
  const bool quantize_weights = std::get<0>(params);
  const bool asymmetric_quantize_inputs = std::get<1>(params);
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/false,
                              /*merge_outputs=*/false, quantize_weights,
                              asymmetric_quantize_inputs);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
  float* batch_start = rnn_input;
  float* batch_end = batch_start + input_sequence_size;
  rnn.SetInput(0, batch_start, batch_end);
  rnn.SetInput(input_sequence_size, batch_start, batch_end);

  rnn.Invoke();

  float* golden_fw_start = rnn_golden_fw_output;
  float* golden_fw_end =
      golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
  std::vector<float> fw_expected;
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  EXPECT_THAT(rnn.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(
                  fw_expected, quantize_weights ? 1.42e-2 : 1e-5)));

  float* golden_bw_start = rnn_golden_bw_output;
  float* golden_bw_end =
      golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
  std::vector<float> bw_expected;
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  EXPECT_THAT(rnn.GetBwOutput(),
              ElementsAreArray(ArrayFloatNear(
                  bw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
}

// Same as BlackBox test, but input is reshuffled to time_major format.
TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
  auto params = GetParam();
  const bool quantize_weights = std::get<0>(params);
  const bool asymmetric_quantize_inputs = std::get<1>(params);
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/true,
                              /*merge_outputs=*/false, quantize_weights,
                              asymmetric_quantize_inputs);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
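  // In general, the time_major slice for batch b at time step t starts at
  // offset (t * num_batches + b) * input_size; with num_batches == 2 this
  // yields exactly the two SetInput offsets used below.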
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();

  std::vector<float> fw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  }
  constexpr float kHybridTolerance = 3.57e-1;
  constexpr float kFloatTolerance = 1e-5;
  EXPECT_THAT(
      rnn.GetFwOutput(),
      ElementsAreArray(ArrayFloatNear(
          fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance)));
}

// Same as BlackBox test, but with merged outputs.
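// With merge_outputs=true the model exposes a single output tensor that, for
// each time step, holds the fw_units forward activations followed by the
// bw_units backward activations; the expected vector below is assembled in
// that same order.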
TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
  auto params = GetParam();
  const bool quantize_weights = std::get<0>(params);
  const bool asymmetric_quantize_inputs = std::get<1>(params);
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/false,
                              /*merge_outputs=*/true, quantize_weights,
                              asymmetric_quantize_inputs);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
  float* batch_start = rnn_input;
  float* batch_end = batch_start + input_sequence_size;
  rnn.SetInput(0, batch_start, batch_end);
  rnn.SetInput(input_sequence_size, batch_start, batch_end);

  rnn.Invoke();

  std::vector<float> merged_expected;
  for (int bid = 0; bid < rnn.num_batches(); bid++) {
    for (int step = 0; step < rnn.sequence_len(); step++) {
      merged_expected.insert(
          merged_expected.end(),
          rnn_golden_fw_output + rnn.num_fw_units() * step,
          rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
      merged_expected.insert(
          merged_expected.end(),
          rnn_golden_bw_output + rnn.num_bw_units() * step,
          rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
    }
  }
  EXPECT_THAT(rnn.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(
                  merged_expected, quantize_weights ? 1.42e-2 : 1e-5)));
}

// Same as BlackBox test, but input is reshuffled to time_major format and the
// outputs are merged.
TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajorMergeOutputs) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/true,
                              /*merge_outputs=*/true);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();

  std::vector<float> merged_expected;
  for (int step = 0; step < rnn.sequence_len(); step++) {
    for (int bid = 0; bid < rnn.num_batches(); bid++) {
      merged_expected.insert(
          merged_expected.end(),
          rnn_golden_fw_output + rnn.num_fw_units() * step,
          rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
      merged_expected.insert(
          merged_expected.end(),
          rnn_golden_bw_output + rnn.num_bw_units() * step,
          rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
    }
  }
  EXPECT_THAT(rnn.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(merged_expected)));
}

// Check that if the input sequence is reversed, the outputs stay the same,
// except that the forward and backward outputs are swapped (and each is
// reversed).
TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the
  // following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1].
|
|
for (int i = 0; i < rnn.sequence_len(); i++) {
|
|
float* batch_start = rnn_input + i * rnn.input_size();
|
|
float* batch_end = batch_start + rnn.input_size();
|
|
const int reverse_idx = rnn.sequence_len() - i - 1;
|
|
rnn.SetInput(reverse_idx * rnn.input_size(), batch_start, batch_end);
|
|
rnn.SetInput((rnn.sequence_len() + reverse_idx) * rnn.input_size(),
|
|
batch_start, batch_end);
|
|
}
|
|
|
|
rnn.Invoke();
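
  // Prepending each step below reverses the time order of the golden
  // sequences, matching the reversed inputs.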
  // The forward and backward outputs are swapped.
  std::vector<float> fw_expected;  // consider using std::deque instead.
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_bw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.begin(), golden_fw_start, golden_fw_end);
  }
  // Duplicate the expectation for the second, identical batch. Copy through a
  // temporary: inserting a vector's own range into itself is undefined
  // behavior.
  std::vector<float> fw_expected_batch(fw_expected);
  fw_expected.insert(fw_expected.end(), fw_expected_batch.begin(),
                     fw_expected_batch.end());
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));

  std::vector<float> bw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_bw_start = rnn_golden_fw_output + i * rnn.num_bw_units();
    float* golden_bw_end = golden_bw_start + rnn.num_bw_units();
    bw_expected.insert(bw_expected.begin(), golden_bw_start, golden_bw_end);
  }
  std::vector<float> bw_expected_batch(bw_expected);
  bw_expected.insert(bw_expected.end(), bw_expected_batch.begin(),
                     bw_expected_batch.end());
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

// Tests an end-to-end neural network with a Bidirectional RNN followed by a
// DNN that aggregates the outputs from the two sequences.
TEST(BidirectionalRNNOpTest, EndToEndTest) {
  BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  const int output_size = 4;
  float dnn_weights[] = {
      -0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139,
      -0.23420811, -0.39647382, 0.31423986, 0.61819065, -0.73659575,
      -0.89698344, -0.8931554, -0.0845688, 0.5617367, 0.38415289,
      -0.11487955, -0.7617774, 0.17927337, 0.15726972, 0.059798479,
      0.19009054, -0.27616632, -0.39142907, 0.77744663, -0.046830714,
      -0.6603595, 0.21945822, 0.051494241, 0.23785079, 0.19239247,
      -0.53268754, 0.65961659, -0.85981959, -0.80232513, 0.84745562,
      -0.66070104, -0.036533296, -0.54901814, 0.65353882, -0.41834265,
      -0.28561389, 0.75655544, -0.31149811, 0.62981737, 0.31829214,
      -0.92734522, -0.48506218, 0.55651462, 0.25192821, 0.67220747,
      -0.3836869, -0.55798125, -0.60395885, 0.22488403, -0.78053463,
      0.3492105, 0.56452453, 0.4389236, -0.59929526, -0.19762468,
      -0.36868393, -0.13198286, -0.53800809, -0.22850353};

  std::initializer_list<float> dnn_biases = {0.29177809, -0.98799044,
                                             0.065919638, 0.68781924};

  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
  const int output_sequence_size = output_size * rnn.sequence_len();
  const int num_examples = 64;
  for (int k = 0; k < num_examples; k++) {
    float* batch_start = endtoend_input + k * input_sequence_size;
    float* batch_end = batch_start + input_sequence_size;
    rnn.SetInput(0, batch_start, batch_end);

    rnn.Invoke();
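
    // Aggregate the two directions by summing the forward and backward
    // outputs elementwise, then apply the 16x4 DNN layer: for each time step,
    // results = dnn_weights^T * rnn_output + dnn_biases.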
    std::vector<float> fw_output = rnn.GetFwOutput();
    std::vector<float> bw_output = rnn.GetBwOutput();
    EXPECT_EQ(fw_output.size(), bw_output.size());

    std::transform(fw_output.begin(), fw_output.end(), bw_output.begin(),
                   fw_output.begin(), std::plus<float>());

    std::vector<float> sequence_result;
    for (int s = 0; s < rnn.sequence_len(); s++) {
      const float* rnn_output = fw_output.data() + s * rnn.num_fw_units();
      std::vector<float> results(dnn_biases);
      for (int i = 0; i < output_size; i++) {
        for (int j = 0; j < rnn.num_fw_units(); j++) {
          results[i] += *(rnn_output + j) * dnn_weights[output_size * j + i];
        }
      }
      sequence_result.insert(sequence_result.end(), results.begin(),
                             results.end());
    }

    float* golden_start = golden_endtoend_output + k * output_sequence_size;
    float* golden_end = golden_start + output_sequence_size;

    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    EXPECT_THAT(sequence_result, ElementsAreArray(ArrayFloatNear(expected)));
  }
}

// Same as BlackBox test, but has an auxiliary input. The layer has no
// cross-linking, i.e. the regular input is passed as an input to the forward
// network only and the auxiliary input is passed as an input to the backward
// network only.
TEST(BidirectionalRNNOpTest, BlackBoxTestNoCrossLinkingRegularAndAuxInput) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    // Also make aux input the same as input.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();
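
  // Each golden row is inserted twice per step: the two batches are identical
  // and the time_major output interleaves batch 0 and batch 1 at each step.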
  std::vector<float> fw_expected;
  std::vector<float> bw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);

    float* golden_bw_start = rnn_golden_bw_output + i * rnn.num_bw_units();
    float* golden_bw_end = golden_bw_start + rnn.num_bw_units();
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  }
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

// Same as above but the auxiliary input is set to zeroes. This test makes sure
// that the forward network works as expected in a no-cross-linking mode.
TEST(BidirectionalRNNOpTest, BlackBoxTestNoCrossLinkingRegularInputOnly) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Initialize the backward (auxiliary) inputs with zeros, one time step wide.
  std::vector<float> bw_inputs(rnn.aux_input_size(), 0);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical; the aux input is all zeros.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput(2 * i * rnn.input_size(), bw_inputs.data(),
                    bw_inputs.data() + bw_inputs.size());
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), bw_inputs.data(),
                    bw_inputs.data() + bw_inputs.size());
  }

  rnn.Invoke();
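
  // Only the forward output is checked: the regular input is untouched, so
  // the forward network must reproduce the golden forward output, while the
  // backward network only ever sees zeros.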
  std::vector<float> fw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  }
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
}

// Same as above but the regular (i.e. not auxiliary) input is set to zeroes.
// This test makes sure that the backward network works as expected in a
// no-cross-linking mode.
TEST(BidirectionalRNNOpTest, BlackBoxTestNoCrossLinkingAuxInputOnly) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Initialize the forward (regular) inputs with zeros, one time step wide.
  std::vector<float> fw_inputs(rnn.input_size(), 0);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical; the regular input is all zeros.
    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput(2 * i * rnn.input_size(), fw_inputs.data(),
                 fw_inputs.data() + fw_inputs.size());
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), fw_inputs.data(),
                 fw_inputs.data() + fw_inputs.size());
  }

  rnn.Invoke();
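
  // Only the backward output is checked: the aux input carries the real
  // sequence to the backward network, while the forward network only ever
  // sees zeros.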
  std::vector<float> bw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_bw_start = rnn_golden_bw_output + i * rnn.num_bw_units();
    float* golden_bw_end = golden_bw_start + rnn.num_bw_units();
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  }
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

// Same as the BlackBox test, but the input is passed to the auxiliary input
// instead of the regular one. Regular input and weights are set to zero.
TEST(BidirectionalRNNOpTest, BlackBoxTestCrossLinkingAuxInputOnly) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetBwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);
  rnn.SetAuxFwWeights(weights);
  rnn.SetAuxBwWeights(weights);

  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
  std::vector<float> zero_input(input_sequence_size, 0.f);
  float* batch_start = rnn_input;
  float* batch_end = batch_start + input_sequence_size;
  // Set batch 0 inputs.
  rnn.SetInput(0, zero_input.data(), zero_input.data() + zero_input.size());
  rnn.SetAuxInput(0, batch_start, batch_end);
  // Set batch 1 inputs.
  rnn.SetInput(input_sequence_size, zero_input.data(),
               zero_input.data() + zero_input.size());
  rnn.SetAuxInput(input_sequence_size, batch_start, batch_end);

  rnn.Invoke();
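
  // With zero regular weights and the original weights on the aux path, the
  // aux input contributes exactly what the regular input did in the BlackBox
  // test, so the golden outputs carry over unchanged.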
  float* golden_fw_start = rnn_golden_fw_output;
  float* golden_fw_end =
      golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
  std::vector<float> fw_expected;
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));

  float* golden_bw_start = rnn_golden_bw_output;
  float* golden_bw_end =
      golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
  std::vector<float> bw_expected;
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

// Same as the BlackBox test, but the input is passed to the auxiliary input
// instead of the regular one. Regular input and weights are set to zero. Time
// major inputs and outputs.
TEST(BidirectionalRNNOpTest, BlackBoxTestCrossLinkingAuxInputOnlyTimeMajor) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetBwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);
  rnn.SetAuxFwWeights(weights);
  rnn.SetAuxBwWeights(weights);

  // One time step of zeros for the regular input.
  std::vector<float> zero_input(rnn.input_size(), 0.f);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    // Set batch 0 inputs.
    rnn.SetInput(2 * i * rnn.input_size(), zero_input.data(),
                 zero_input.data() + zero_input.size());
    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
    // Set batch 1 inputs.
    rnn.SetInput((2 * i + 1) * rnn.input_size(), zero_input.data(),
                 zero_input.data() + zero_input.size());
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();
  std::vector<float> fw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  }
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
}

// Same as the BlackBox test, but the input tensor and weights tensor are split
// along the last dimension and passed to both regular and auxiliary inputs and
// weights. The output in this case is the same. To understand this, define W
// and V as the regular and auxiliary input weight matrices respectively. This
// is equivalent to a regular RNN with weights U = (W|V) and z^T = (x^T|y^T),
// where .|. denotes concatenation along the horizontal axis:
//   f(z) = Uz + b
// is equivalent to:
//   f((x^T|y^T)^T) = (Wx + Vy) + b.
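//
// A tiny worked example (hypothetical numbers; one unit, two inputs on each
// path): W = (1 2), V = (3 4), x = (5 6)^T, y = (7 8)^T. Then U = (1 2 3 4),
// z = (5 6 7 8)^T, and Uz = 5 + 12 + 21 + 32 = 70, which equals
// Wx + Vy = 17 + 53 = 70.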
void run_blackbox_test_with_input_split(int input_size, int aux_input_size) {
  const int num_units = 16;
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/num_units, /*bw_units=*/num_units,
                              input_size, aux_input_size,
                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  std::vector<float> reg_weights(num_units * rnn.input_size());
  std::vector<float> aux_weights(num_units * rnn.aux_input_size());
  int full_weights_size = weights.size();
  int reg_weights_offset = 0;
  int aux_weights_offset = 0;
  int weights_offset = 0;
  // Alternate copying into the regular and auxiliary input weights to split
  // the original weight matrix along its last axis: the first input_size
  // entries of each row go to reg_weights, the remaining aux_input_size
  // entries to aux_weights.
  while (weights_offset < full_weights_size) {
    std::copy(weights.begin() + weights_offset,
              weights.begin() + weights_offset + rnn.input_size(),
              reg_weights.begin() + reg_weights_offset);
    weights_offset += rnn.input_size();
    reg_weights_offset += rnn.input_size();

    std::copy(weights.begin() + weights_offset,
              weights.begin() + weights_offset + rnn.aux_input_size(),
              aux_weights.begin() + aux_weights_offset);
    weights_offset += rnn.aux_input_size();
    aux_weights_offset += rnn.aux_input_size();
  }
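  // E.g. for input_size=2, aux_input_size=6, row k of the original 8-wide
  // weight matrix lands in entries [2k, 2k+2) of reg_weights and
  // [6k, 6k+6) of aux_weights.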

  rnn.SetFwWeights(reg_weights);
  rnn.SetBwWeights(reg_weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);
  rnn.SetAuxFwWeights(aux_weights);
  rnn.SetAuxBwWeights(aux_weights);

  int full_input_size =
      (rnn.input_size() + rnn.aux_input_size()) * rnn.sequence_len();
  int reg_input_offset = 0;
  int aux_input_offset = 0;
  // Alternate copying into the regular and auxiliary input tensors to split
  // the original input matrix along its last axis.
  for (int batch = 0; batch < 2; ++batch) {
    int input_offset = 0;
    while (input_offset < full_input_size) {
      rnn.SetInput(reg_input_offset, rnn_input + input_offset,
                   rnn_input + input_offset + rnn.input_size());
      input_offset += rnn.input_size();
      reg_input_offset += rnn.input_size();

      rnn.SetAuxInput(aux_input_offset, rnn_input + input_offset,
                      rnn_input + input_offset + rnn.aux_input_size());
      input_offset += rnn.aux_input_size();
      aux_input_offset += rnn.aux_input_size();
    }
  }

  rnn.Invoke();
  float* golden_fw_start = rnn_golden_fw_output;
  float* golden_fw_end =
      golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
  std::vector<float> fw_expected;
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));

  float* golden_bw_start = rnn_golden_bw_output;
  float* golden_bw_end =
      golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
  std::vector<float> bw_expected;
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

TEST(BidirectionalRNNOpTest,
     BlackBoxTestCrossLinkingRegularAndAuxInputEvenSplit) {
  run_blackbox_test_with_input_split(/*input_size=*/4, /*aux_input_size=*/4);
}

// Same as above, but the input tensor and the weights tensor are split
// unevenly.
TEST(BidirectionalRNNOpTest,
     BlackBoxTestCrossLinkingRegularAndAuxInputUnevenSplit) {
  run_blackbox_test_with_input_split(/*input_size=*/2, /*aux_input_size=*/6);
}

}  // namespace
}  // namespace tflite