Use GetVariableInput (or GetInput in const context) for the activation tensor in SVDF, instead of accessing it through its index.
PiperOrigin-RevId: 294748701 Change-Id: I6208f0af2b2db8d294fef559eb650b9c66cc9b6d
This commit is contained in:
parent
b9b688ebd2
commit
24fa7c6822
@ -39,7 +39,6 @@ namespace {
|
||||
struct OpData {
|
||||
int scratch_tensor_index;
|
||||
bool float_weights_time_initialized;
|
||||
int activation_state_tensor_index;
|
||||
int32 effective_scale_1_a;
|
||||
int effective_scale_1_b;
|
||||
int32 effective_scale_2_a;
|
||||
@ -80,8 +79,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Check we have all the inputs and outputs we need.
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
|
||||
op_data->activation_state_tensor_index =
|
||||
node->inputs->data[kInputActivationStateTensor];
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* weights_feature =
|
||||
@ -109,8 +106,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
|
||||
}
|
||||
|
||||
TfLiteTensor* activation_state =
|
||||
&context->tensors[op_data->activation_state_tensor_index];
|
||||
const TfLiteTensor* activation_state =
|
||||
GetInput(context, node, kInputActivationStateTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
// Check the shape of input state tensors.
|
||||
@ -250,7 +247,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
|
||||
|
||||
TfLiteTensor* activation_state =
|
||||
&context->tensors[op_data->activation_state_tensor_index];
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
switch (weights_feature->type) {
|
||||
|
@ -377,8 +377,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* weights_time =
|
||||
GetInput(context, node, kWeightsTimeTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* activation_state =
|
||||
&context->tensors[node->inputs->data[kInputActivationStateTensor]];
|
||||
const TfLiteTensor* activation_state =
|
||||
GetInput(context, node, kInputActivationStateTensor);
|
||||
|
||||
// Define input constants based on input tensor definition above:
|
||||
const int rank = params->rank;
|
||||
@ -491,7 +491,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetInput(context, node, kWeightsTimeTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* activation_state =
|
||||
&context->tensors[node->inputs->data[kInputActivationStateTensor]];
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
const bool is_full_integer = input->type == kTfLiteInt8;
|
||||
|
@ -296,8 +296,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* weights_time =
|
||||
GetInput(context, node, kWeightsTimeTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* activation_state =
|
||||
&context->tensors[node->inputs->data[kInputActivationStateTensor]];
|
||||
const TfLiteTensor* activation_state =
|
||||
GetInput(context, node, kInputActivationStateTensor);
|
||||
|
||||
// Define input constants based on input tensor definition above:
|
||||
const int rank = params->rank;
|
||||
@ -409,7 +409,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetInput(context, node, kWeightsTimeTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* activation_state =
|
||||
&context->tensors[node->inputs->data[kInputActivationStateTensor]];
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user