Use GetVariableInput (or GetInput in const context) for the activation tensor in SVDF, instead of accessing it through its index.

PiperOrigin-RevId: 294748701
Change-Id: I6208f0af2b2db8d294fef559eb650b9c66cc9b6d
This commit is contained in:
Robert David 2020-02-12 13:53:18 -08:00 committed by TensorFlower Gardener
parent b9b688ebd2
commit 24fa7c6822
3 changed files with 9 additions and 12 deletions

View File

@ -39,7 +39,6 @@ namespace {
struct OpData {
int scratch_tensor_index;
bool float_weights_time_initialized;
int activation_state_tensor_index;
int32 effective_scale_1_a;
int effective_scale_1_b;
int32 effective_scale_2_a;
@ -80,8 +79,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
op_data->activation_state_tensor_index =
node->inputs->data[kInputActivationStateTensor];
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
@ -109,8 +106,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
}
TfLiteTensor* activation_state =
&context->tensors[op_data->activation_state_tensor_index];
const TfLiteTensor* activation_state =
GetInput(context, node, kInputActivationStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Check the shape of input state tensors.
@ -250,7 +247,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* activation_state =
&context->tensors[op_data->activation_state_tensor_index];
GetVariableInput(context, node, kInputActivationStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (weights_feature->type) {

View File

@ -377,8 +377,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* activation_state =
&context->tensors[node->inputs->data[kInputActivationStateTensor]];
const TfLiteTensor* activation_state =
GetInput(context, node, kInputActivationStateTensor);
// Define input constants based on input tensor definition above:
const int rank = params->rank;
@ -491,7 +491,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* activation_state =
&context->tensors[node->inputs->data[kInputActivationStateTensor]];
GetVariableInput(context, node, kInputActivationStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const bool is_full_integer = input->type == kTfLiteInt8;

View File

@ -296,8 +296,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* activation_state =
&context->tensors[node->inputs->data[kInputActivationStateTensor]];
const TfLiteTensor* activation_state =
GetInput(context, node, kInputActivationStateTensor);
// Define input constants based on input tensor definition above:
const int rank = params->rank;
@ -409,7 +409,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* activation_state =
&context->tensors[node->inputs->data[kInputActivationStateTensor]];
GetVariableInput(context, node, kInputActivationStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);