diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc
index 74bdb9a0827..f437f74289c 100644
--- a/tensorflow/lite/kernels/activations.cc
+++ b/tensorflow/lite/kernels/activations.cc
@@ -65,6 +65,8 @@ struct SoftmaxOpData {
 struct LogSoftmaxOpData : public OpData {
   int32_t reverse_scaling_divisor = 0;
   int32_t reverse_scaling_right_shift = 0;
+  struct SoftmaxParams params = {};
+  float f_table[256];
 };
 
 struct LeakyReluOpData : public OpData {
@@ -469,23 +471,28 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, input->type, output->type);
 
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
+    static const double kBeta = 1.0;
     if (input->type == kTfLiteUInt8) {
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
+      data->params.table = data->f_table;
+      optimized_ops::PopulateSoftmaxLookupTable(&data->params,
+                                                input->params.scale, kBeta);
+      data->params.zero_point = output->params.zero_point;
+      data->params.scale = output->params.scale;
     }
     if (input->type == kTfLiteInt8) {
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127);
+      static const int kScaledDiffIntegerBits = 5;
+      tflite::PreprocessLogSoftmaxScalingExp(
+          kBeta, input->params.scale, kScaledDiffIntegerBits,
+          &data->input_multiplier, &data->input_left_shift,
+          &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift);
+      data->reverse_scaling_right_shift *= -1;
+      data->diff_min =
+          -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+                                              data->input_left_shift);
     }
-    TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
-
-    static const double kBeta = 1.0;
-    static const int kScaledDiffIntegerBits = 5;
-    tflite::PreprocessLogSoftmaxScalingExp(
-        kBeta, input->params.scale, kScaledDiffIntegerBits,
-        &data->input_multiplier, &data->input_left_shift,
-        &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift);
-    data->reverse_scaling_right_shift *= -1;
-    data->diff_min = -1.0 * tflite::CalculateInputRadius(
-                                kScaledDiffIntegerBits, data->input_left_shift);
   }
 
   return context->ResizeTensor(context, output,
@@ -927,16 +934,12 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     }
     case kTfLiteUInt8: {
-      SoftmaxParams op_params;
-      op_params.input_multiplier = data->input_multiplier;
-      op_params.input_left_shift = data->input_left_shift;
-      op_params.reverse_scaling_divisor = data->reverse_scaling_divisor;
-      op_params.reverse_scaling_right_shift = data->reverse_scaling_right_shift;
-      op_params.diff_min = data->diff_min;
+      SoftmaxParams op_params = data->params;
       if (kernel_type == kGenericOptimized) {
         optimized_ops::LogSoftmax(
-            op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-            GetTensorShape(output), GetTensorData<uint8_t>(output));
+            op_params, input->params.scale, GetTensorShape(input),
+            GetTensorData<uint8_t>(input), GetTensorShape(output),
+            GetTensorData<uint8_t>(output));
       } else {
         reference_ops::LogSoftmax(
             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc
index 67fbee58162..b837afacfb5 100644
--- a/tensorflow/lite/kernels/activations_test.cc
+++ b/tensorflow/lite/kernels/activations_test.cc
@@ -1253,6 +1253,7 @@ TEST(FloatActivationsOpTest, LogSoftmax) {
 
 TEST(QuantizedActivationsOpTest, LogSoftmaxUint8) {
   const float kLogSoftmaxQuantizedTolerance = 16 / 256.0;
+  // Corresponds to input scale of 20/255.
   QuantizedActivationsOpModel m(
       BuiltinOperator_LOG_SOFTMAX,
       /*input=*/{TensorType_UINT8, {2, 4}, -10, 10},
diff --git a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc
index b98e8234454..72e4685d1e9 100644
--- a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc
+++ b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -171,13 +171,19 @@ void RunOneLogSoftmaxTest(const uint8* input_data,
                                                      input_beta_left_shift);
 
   SoftmaxParams params;
+  float table[256];
   params.input_multiplier = input_beta_multiplier;
   params.input_left_shift = input_beta_left_shift;
   params.reverse_scaling_divisor = reverse_scaling_divisor;
   params.reverse_scaling_right_shift = reverse_scaling_right_shift;
   params.diff_min = diff_min;
-  optimized_ops::LogSoftmax(params, shape_common, input_data, shape_common,
-                            optimized_logsoftmax_output.data());
+
+  params.scale = 1.0f / 16.0f;
+  params.zero_point = 255;
+  params.table = table;
+  optimized_ops::PopulateSoftmaxLookupTable(&params, input_scale, beta);
+  optimized_ops::LogSoftmax(params, input_scale, shape_common, input_data,
+                            shape_common, optimized_logsoftmax_output.data());
   reference_ops::LogSoftmax(params, shape_common, input_data, shape_common,
                             reference_quant_logsoftmax_output.data());
 
@@ -186,7 +192,7 @@ void RunOneLogSoftmaxTest(const uint8* input_data,
                            shape_common, "Optimized vs float reference", false);
   CheckOutputData<uint8_t>(optimized_logsoftmax_output.data(),
                            reference_quant_logsoftmax_output.data(),
-                           shape_common, "Optimized vs quant reference", true);
+                           shape_common, "Optimized vs quant reference", false);
   CheckOutputData<uint8_t>(reference_quant_logsoftmax_output.data(),
                            reference_float_logsoftmax_output.data(),
                            shape_common, "Quant reference vs float reference",
diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
index b9305169065..9edef104cfb 100644
--- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -4232,7 +4232,8 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
   params.reverse_scaling_divisor = reverse_scaling_divisor;
   params.reverse_scaling_right_shift = reverse_scaling_right_shift;
   params.diff_min = diff_min;
-  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+  reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
+                            output_data);
 }
 
 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
@@ -4240,10 +4241,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
                        int32 reverse_scaling_divisor,
                        int32 reverse_scaling_right_shift, int diff_min,
                        uint8* output_data, const Dims<4>& output_dims) {
-  LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
-             input_left_shift, reverse_scaling_divisor,
-             reverse_scaling_right_shift, diff_min, output_data,
-             DimsToShape(output_dims));
+  reference_ops::LogSoftmax(
+      input_data, DimsToShape(input_dims), input_multiplier, input_left_shift,
+      reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
+      output_data, DimsToShape(output_dims));
 }
 
 inline void Logistic(const LogisticParams& params,
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index ba7b0fd2f32..815a32a25e0 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -3621,93 +3621,77 @@ inline void LogSoftmax(const SoftmaxParams& params,
   }
 }
 
-// Currently just a copy of the reference code.
+// Backwards compatibility. Less optimized than below version.
 inline void LogSoftmax(const SoftmaxParams& params,
                        const RuntimeShape& input_shape, const uint8* input_data,
                        const RuntimeShape& output_shape, uint8* output_data) {
-  gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
-  const int32 input_multiplier = params.input_multiplier;
-  const int32 input_left_shift = params.input_left_shift;
-  const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
-  const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
-  const int diff_min = params.diff_min;
-  // The representation chosen for the input to the exp() function is Q5.26.
-  // We need to leave extra space since values that we skip might be as large as
-  // -32 before multiplying by input_beta_multiplier, and therefore as large as
-  // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
-  // accumulation, but exp(-16) definitely is.
-  static constexpr int kScaledDiffIntegerBits = 5;
-  static constexpr int kAccumulationIntegerBits = 12;
-  static constexpr int kOutputIntegerBits = 4;
-  using FixedPointScaledDiff =
-      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
-  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+  reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
+                            output_data);
+}
 
+// Compute LogSoftmax as (x - x_max) - ln(sum(e^(x_i - x_max)...)
+// as done in tf.nn.log_softmax to prevent underflow and overflow.
+// This is in contrast to just log(softmax(x))
+//
+// To handle quantization, first dequantize the inputs (from doing
+// e^(input scale * val) where we ignore the zero point since it cancels
+// out during subtraction due to the ln) and do a rescale at the end to int8.
+//
+// Notably this makes use of float and is intended as the optimized
+// form for quantized execution on CPU. For a fully integer version,
+// see the reference op.
+//
+// TODO(tflite): notes for optimization:
+// 1) See if e^ is also bottleneck in the reference fully-integer
+// version and apply lookup there and compare.
+// 2) Time spent is currently split between computing max_val, the
+// rint call, and the computation of log_prob.
+inline void LogSoftmax(const SoftmaxParams& params, float input_scale,
+                       const RuntimeShape& input_shape, const uint8* input_data,
+                       const RuntimeShape& output_shape, uint8* output_data) {
+  gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
   const int trailing_dim = input_shape.DimensionsCount() - 1;
-  const int outer_size =
+  const int excluding_last_dim =
       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
-  const int depth =
+  const int last_dim =
       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
 
-  for (int i = 0; i < outer_size; ++i) {
-    const uint8* block_input_data = input_data + i * depth;
-    uint8* block_output_data = output_data + i * depth;
-    uint8 max_in_row = 0;
-    for (int c = 0; c < depth; ++c) {
-      max_in_row = std::max(max_in_row, block_input_data[c]);
+  const int32_t clamp_max = std::numeric_limits<uint8>::max();
+  const int32_t clamp_min = std::numeric_limits<uint8>::min();
+  for (int i = 0; i < excluding_last_dim; ++i) {
+    uint8_t max_val = std::numeric_limits<uint8>::min();
+    // Find max quantized value.
+    for (int j = 0; j < last_dim; ++j) {
+      max_val = std::max(max_val, input_data[j]);
     }
 
-    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
-    for (int c = 0; c < depth; ++c) {
-      int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
-      if (input_diff >= diff_min) {
-        const int32 input_diff_rescaled =
-            MultiplyByQuantizedMultiplierGreaterThanOne(
-                input_diff, input_multiplier, input_left_shift);
-        const FixedPointScaledDiff scaled_diff_f8 =
-            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
-                                        exp_on_negative_values(scaled_diff_f8));
-      }
+    float sum_exp = 0.0f;
+    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
+    // Offset into table to compute exp(scale*(x - xmax)) instead of
+    // exp(scale*(x)) to prevent overflow.
+    const float* table_offset = &params.table[max_uint8 - max_val];
+    // Calculate sum(exp(scale*(x - x_max))).
+    for (int j = 0; j < last_dim; ++j) {
+      sum_exp += table_offset[input_data[j]];
     }
+    const float log_sum_exp = std::log(sum_exp);
 
-    const int32 fixed_log_sum_of_exps =
-        log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
-            sum_of_exps)
-            .raw();
+    const float precomputed = input_scale * max_val + log_sum_exp;
+    for (int j = 0; j < last_dim; ++j) {
+      // Equivalent to input_scale * (input_data[j] - max_val) - log_sum_exp;
+      const float log_prob = input_scale * input_data[j] - precomputed;
 
-    // rescaled_diff_min is smallest representable in
-    // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
-    // log-sub-exps that will be subtracted in the loop.
-    //
-    // The thresholds diff_min, etc are negative.
-    const int rescaled_diff_min =
-        fixed_log_sum_of_exps + std::numeric_limits<int32>::lowest();
-    const int adjusted_diff_min =
-        std::max(diff_min - 1,  // Note use of > below instead of >= above.
-                 MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                     rescaled_diff_min, reverse_scaling_divisor,
-                     -reverse_scaling_right_shift));
+      // TODO(tflite): look into better solution.
+      // Use std::rint over std::round (which is used in
+      // FakeQuant) since it's multiple times faster on tested arm32.
+      const int32_t prob_quantized =
+          std::rint(log_prob / params.scale) + params.zero_point;
 
-    for (int c = 0; c < depth; ++c) {
-      int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
-      if (input_diff > adjusted_diff_min) {
-        const int32 input_diff_rescaled =
-            MultiplyByQuantizedMultiplierGreaterThanOne(
-                input_diff, input_multiplier, input_left_shift);
-        int32 unsat_output =
-            gemmlowp::RoundingDivideByPOT(
-                (input_diff_rescaled - fixed_log_sum_of_exps),
-                31 - kScaledDiffIntegerBits - kOutputIntegerBits) +
-            255;
-
-        block_output_data[c] = static_cast<uint8>(
-            std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
-      } else {
-        // Set output to smallest value.
-        block_output_data[c] = 0;
-      }
+      output_data[j] = static_cast<uint8_t>(
+          std::max(std::min(clamp_max, prob_quantized), clamp_min));
     }
+    input_data += last_dim;
+    output_data += last_dim;
   }
 }