From 0c04770ad15b2a87f0b5ff82bf6b43d982f5c6f9 Mon Sep 17 00:00:00 2001 From: "T.J. Alumbaugh" Date: Fri, 15 Nov 2019 12:06:46 -0800 Subject: [PATCH] Ruy: Permit GEMV code to thread if thread_count above 1 PiperOrigin-RevId: 280708178 Change-Id: I519ca15e4b52d2e4483dae63418ebfb7b6b8e31f --- tensorflow/lite/experimental/ruy/trmul.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/experimental/ruy/trmul.cc b/tensorflow/lite/experimental/ruy/trmul.cc index 0dcac3c6092..fbebc77de88 100644 --- a/tensorflow/lite/experimental/ruy/trmul.cc +++ b/tensorflow/lite/experimental/ruy/trmul.cc @@ -261,11 +261,13 @@ int GetThreadCount(Context* context, int rows, int cols, int depth) { LoopStructure GetLoopStructure(int tentative_thread_count, int rows, int cols, int depth, int cache_friendly_traversal_threshold) { - if (cols == 1) { // Use a simple loop for the GEMV case. - return LoopStructure::kSimple; - } else if (tentative_thread_count == 1 && - (rows + cols) * depth < cache_friendly_traversal_threshold) { - return LoopStructure::kSimple; + if (tentative_thread_count == 1) { + // If we are in the GEMV case or the size is below the + // threshold, stay with the simple loop structure. + if ((cols == 1) || + (rows + cols) * depth < cache_friendly_traversal_threshold) { + return LoopStructure::kSimple; + } } return LoopStructure::kGeneral; }