diff --git a/tensorflow/lite/experimental/micro/kernels/svdf.cc b/tensorflow/lite/experimental/micro/kernels/svdf.cc index 866a0286553..756c3c7ccd3 100644 --- a/tensorflow/lite/experimental/micro/kernels/svdf.cc +++ b/tensorflow/lite/experimental/micro/kernels/svdf.cc @@ -359,7 +359,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // [3] = Bias (optional), {1, num_units} // [4] = Activation State (variable), // {2, batch_size, memory_size * num_filters} - TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + // TODO(kreeger): Use input tensor as variable until scratch tensor allocation + // has been implemented (cl/263032056) + // TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 6); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* weights_feature = GetInput(context, node, kWeightsFeatureTensor); @@ -408,7 +411,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // [0] = Holds dot-product of time-forward calculations in // ApplyTimeWeightsBiasAndActivation(): // float, {2, batch_size, num_filters} - TfLiteTensor* scratch_tensor = GetTemporary(context, node, 0); + // TODO(kreeger): Use input tensor as variable until scratch tensor allocation + // has been implemented (cl/263032056) + // TfLiteTensor* scratch_tensor = GetTemporary(context, node, 0); + TfLiteTensor* scratch_tensor = &context->tensors[node->inputs->data[5]]; + TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, NumDimensions(scratch_tensor), 2); TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[0], batch_size); @@ -479,7 +486,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Full-float SVDF only uses the one shared scratch tensor (see above for // usage). - TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1); + // TODO(kreeger): Use input tensor as variable until scratch tensor + // allocation has been implemented (cl/263032056) + // TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1); } // Validate Tensor Output: @@ -504,7 +513,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetInput(context, node, kWeightsTimeTensor); const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); + // TODO(kreeger): Use input tensor as variable until scratch tensor allocation + // has been implemented (cl/263032056) + // TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* scratch = &context->tensors[node->inputs->data[5]]; TfLiteTensor* activation_state = &context->tensors[node->inputs->data[kInputActivationStateTensor]]; diff --git a/tensorflow/lite/experimental/micro/kernels/svdf_test.cc b/tensorflow/lite/experimental/micro/kernels/svdf_test.cc index de3dcb83f66..d1cbe3d6c95 100644 --- a/tensorflow/lite/experimental/micro/kernels/svdf_test.cc +++ b/tensorflow/lite/experimental/micro/kernels/svdf_test.cc @@ -146,7 +146,10 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units, } // Bias is an optional tensor: - int inputs_array_data[] = {5, 0, 1, 2, kOptionalTensor, 3}; + // TODO(kreeger): Use input tensor as variable until scratch tensor allocation + // has been implemented (cl/263032056) + // int inputs_array_data[] = {5, 0, 1, 2, kOptionalTensor, 3}; + int inputs_array_data[] = {6, 0, 1, 2, kOptionalTensor, 3, 5}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 4};