diff --git a/tensorflow/lite/experimental/ruy/test.h b/tensorflow/lite/experimental/ruy/test.h index 4ba0920dfe8..54101b308bb 100644 --- a/tensorflow/lite/experimental/ruy/test.h +++ b/tensorflow/lite/experimental/ruy/test.h @@ -1460,12 +1460,6 @@ void MakeSpecClampFields(Spec* spec) { using AccumScalar = typename Spec::AccumScalar; using DstScalar = typename Spec::DstScalar; - if (getenv("BENCHMARK_ONLY_MATMUL")) { - spec->clamp_min = -std::numeric_limits::infinity(); - spec->clamp_max = std::numeric_limits::infinity(); - return; - } - if (std::is_same::value) { // Returning raw accumulators, clamping is not supported. spec->clamp_min = std::numeric_limits::lowest(); @@ -1473,6 +1467,17 @@ void MakeSpecClampFields(Spec* spec) { return; } + if (getenv("BENCHMARK_ONLY_MATMUL")) { + if (std::is_floating_point::value) { + spec->clamp_min = -std::numeric_limits::infinity(); + spec->clamp_max = std::numeric_limits::infinity(); + } else { + spec->clamp_min = std::numeric_limits::lowest(); + spec->clamp_max = std::numeric_limits::max(); + } + return; + } + spec->clamp_min = std::numeric_limits::lowest() + 1; spec->clamp_max = std::numeric_limits::max() - 1; } @@ -1507,8 +1512,8 @@ template void TestSet::MakeSpec() { RUY_CHECK_EQ(life_stage, LifeStage::kHasLhsRhs); - if (!getenv("BENCHMARK_ONLY_MATMUL") && !benchmark && - (global_random_engine()() & 1)) { + if (!getenv("BENCHMARK_ONLY_MATMUL") && + (benchmark || (global_random_engine()() & 1))) { MakeRandomVector(RandomRange::kBias, rows, &bias_data); spec.bias = bias_data.data(); }