diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD
index 7ccf30d45fd..c10c015c0cb 100644
--- a/tensorflow/lite/testing/BUILD
+++ b/tensorflow/lite/testing/BUILD
@@ -6,6 +6,7 @@ load(
     "merged_test_models",
 )
 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
+load("//tensorflow/lite/testing:tflite_model_test.bzl", "tflite_model_test")
 load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
 load(
     "//tensorflow:tensorflow.bzl",
@@ -21,7 +22,10 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
-exports_files(["generated_examples_zip_test.cc"])
+exports_files([
+    "generated_examples_zip_test.cc",
+    "tflite_diff_example_test.cc",
+])
 
 [gen_zip_test(
     name = "zip_test_%s" % test_name,
@@ -309,13 +313,22 @@ cc_library(
         ":join",
         ":split",
         ":test_runner",
-        "//tensorflow/core:core_cpu",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:tensorflow",
-        "//tensorflow/lite:string_util",
         "@com_google_absl//absl/strings",
-    ],
+        "//tensorflow/lite:string_util",
+    ] + select({
+        "//conditions:default": [
+            "//tensorflow/core:core_cpu",
+            "//tensorflow/core:framework",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:tensorflow",
+        ],
+        "//tensorflow:android": [
+            "//tensorflow/core:android_tensorflow_lib",
+        ],
+        "//tensorflow:ios": [
+            "//tensorflow/core:ios_tensorflow_lib",
+        ],
+    }),
 )
 
 tf_cc_test(
@@ -342,9 +355,18 @@ cc_library(
         ":join",
         ":split",
         ":tf_driver",
-        "//tensorflow/core:framework",
         "//tensorflow/lite:string",
-    ],
+    ] + select({
+        "//conditions:default": [
+            "//tensorflow/core:framework",
+        ],
+        "//tensorflow:android": [
+            "//tensorflow/core:android_tensorflow_lib",
+        ],
+        "//tensorflow:ios": [
+            "//tensorflow/core:ios_tensorflow_lib",
+        ],
+    }),
 )
 
 tf_cc_test(
@@ -404,6 +426,7 @@ cc_library(
         ":split",
         ":tflite_diff_util",
         ":tflite_driver",
+        "@com_google_absl//absl/strings",
     ] + select({
         "//conditions:default": [
             "//tensorflow/core:framework_internal",
@@ -453,6 +476,20 @@ tf_cc_binary(
     ],
 )
 
+tflite_model_test(
+    name = "tflite_model_example_test",
+    input_layer = "a,b,c,d",
+    input_layer_shape = "1,8,8,3:1,8,8,3:1,8,8,3:1,8,8,3",
+    input_layer_type = "float,float,float,float",
+    output_layer = "x,y",
+    tags = [
+        "no_cuda_on_cpu_tap",
+        "no_oss",  # needs test data
+        "tflite_not_portable",  # TODO(b/134772701): Enable after making this a proper GTest.
+    ],
+    tensorflow_model_file = "//tensorflow/lite:testdata/multi_add.pb",
+)
+
 cc_library(
     name = "string_util_lib",
     srcs = ["string_util.cc"],
diff --git a/tensorflow/lite/testing/generate_testspec.cc b/tensorflow/lite/testing/generate_testspec.cc
index 74e4d254983..99021c9f317 100644
--- a/tensorflow/lite/testing/generate_testspec.cc
+++ b/tensorflow/lite/testing/generate_testspec.cc
@@ -13,34 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include <iostream>
-
 #include "tensorflow/lite/testing/generate_testspec.h"
+
+#include <iostream>
+#include <random>
+
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/lite/testing/join.h"
 #include "tensorflow/lite/testing/split.h"
 #include "tensorflow/lite/testing/tf_driver.h"
-#include "tensorflow/core/framework/types.h"
 
 namespace tflite {
 namespace testing {
+namespace {
 
-template <typename T>
-void GenerateCsv(const std::vector<int>& shape, float min, float max,
-                 string* out) {
-  auto random_float = [](float min, float max) {
-    static unsigned int seed;
-    return min + (max - min) * static_cast<float>(rand_r(&seed)) / RAND_MAX;
-  };
-
-  std::function<T(int)> random_t = [&](int) {
-    return static_cast<T>(random_float(min, max));
-  };
-  std::vector<T> data = GenerateRandomTensor(shape, random_t);
+template <typename T, typename RandomEngine, typename RandomDistribution>
+void GenerateCsv(const std::vector<int>& shape, RandomEngine* engine,
+                 RandomDistribution distribution, string* out) {
+  std::vector<T> data =
+      GenerateRandomTensor<T>(shape, [&]() { return distribution(*engine); });
   *out = Join(data.data(), data.size(), ",");
 }
 
+template <typename RandomEngine>
 std::vector<string> GenerateInputValues(
-    const std::vector<string>& input_layer,
+    RandomEngine* engine, const std::vector<string>& input_layer,
     const std::vector<string>& input_layer_type,
     const std::vector<string>& input_layer_shape) {
   std::vector<string> input_values;
@@ -52,19 +49,29 @@ std::vector<string> GenerateInputValues(
 
     switch (type) {
       case tensorflow::DT_FLOAT:
-        GenerateCsv<float>(shape, -0.5, 0.5, &input_values[i]);
+        GenerateCsv<float>(shape, engine,
+                           std::uniform_real_distribution<float>(-0.5, 0.5),
+                           &input_values[i]);
         break;
       case tensorflow::DT_UINT8:
-        GenerateCsv<uint8_t>(shape, 0, 255, &input_values[i]);
+        GenerateCsv<uint8_t>(shape, engine,
+                             std::uniform_int_distribution<uint8_t>(0, 255),
+                             &input_values[i]);
         break;
       case tensorflow::DT_INT32:
-        GenerateCsv<int32_t>(shape, -100, 100, &input_values[i]);
+        GenerateCsv<int32_t>(shape, engine,
+                             std::uniform_int_distribution<int32_t>(-100, 100),
+                             &input_values[i]);
         break;
       case tensorflow::DT_INT64:
-        GenerateCsv<int64_t>(shape, -100, 100, &input_values[i]);
+        GenerateCsv<int64_t>(shape, engine,
+                             std::uniform_int_distribution<int64_t>(-100, 100),
+                             &input_values[i]);
         break;
       case tensorflow::DT_BOOL:
-        GenerateCsv<int>(shape, 0.01, 1.99, &input_values[i]);
+        GenerateCsv<int>(shape, engine,
+                         std::uniform_int_distribution<int>(0, 1),
+                         &input_values[i]);
         break;
       default:
         fprintf(stderr, "Unsupported type %d (%s) when generating testspec.\n",
@@ -76,6 +83,8 @@ std::vector<string> GenerateInputValues(
   return input_values;
 }
 
+}  // namespace
+
 bool GenerateTestSpecFromTensorflowModel(
     std::iostream& stream, const string& tensorflow_model_path,
     const string& tflite_model_path, int num_invocations,
@@ -109,11 +118,12 @@ bool GenerateTestSpecFromTensorflowModel(
   stream << "}\n";
 
   // Generate inputs.
+  std::mt19937 random_engine;
   for (int i = 0; i < num_invocations; ++i) {
     // Note that the input values are random, so each invocation will have a
     // different set.
-    std::vector<string> input_values =
-        GenerateInputValues(input_layer, input_layer_type, input_layer_shape);
+    std::vector<string> input_values = GenerateInputValues(
+        &random_engine, input_layer, input_layer_type, input_layer_shape);
     if (input_values.empty()) {
       std::cerr << "Unable to generate input values for the TensorFlow model. "
                    "Make sure the correct values are defined for "
diff --git a/tensorflow/lite/testing/generate_testspec.h b/tensorflow/lite/testing/generate_testspec.h
index fe7e6ddb3fb..58f8065972b 100644
--- a/tensorflow/lite/testing/generate_testspec.h
+++ b/tensorflow/lite/testing/generate_testspec.h
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_TESTING_GENERATE_TESTSPEC_H_
 #define TENSORFLOW_LITE_TESTING_GENERATE_TESTSPEC_H_
 
+#include <algorithm>
 #include <functional>
 #include <iostream>
 #include <vector>
@@ -46,19 +47,16 @@ bool GenerateTestSpecFromTensorflowModel(
     const std::vector<string>& output_layer);
 
 // Generates random values that are filled into the tensor.
-// random_func returns the generated random element at given index.
-template <typename T>
+template <typename T, typename RandomFunction>
 std::vector<T> GenerateRandomTensor(const std::vector<int>& shape,
-                                    const std::function<T(int)>& random_func) {
+                                    RandomFunction random_func) {
   int64_t num_elements = 1;
   for (const int dim : shape) {
     num_elements *= dim;
   }
 
   std::vector<T> result(num_elements);
-  for (int i = 0; i < num_elements; i++) {
-    result[i] = random_func(i);
-  }
+  std::generate_n(result.data(), num_elements, random_func);
   return result;
 }
 
diff --git a/tensorflow/lite/testing/generate_testspec_test.cc b/tensorflow/lite/testing/generate_testspec_test.cc
index 4450da289d2..1887c8f0cd0 100644
--- a/tensorflow/lite/testing/generate_testspec_test.cc
+++ b/tensorflow/lite/testing/generate_testspec_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/lite/testing/generate_testspec.h"
 
+#include <random>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
@@ -22,16 +24,16 @@ namespace testing {
 namespace {
 
 TEST(GenerateRandomTensor, FloatValue) {
-  static unsigned int seed = 0;
-  std::function<float(int)> float_rand = [](int idx) {
-    return static_cast<float>(rand_r(&seed)) / RAND_MAX - 0.5f;
+  std::mt19937 random_engine;
+  auto random_func = [&]() {
+    return std::uniform_real_distribution<float>(-0.5, 0.5)(random_engine);
   };
 
   std::set<float> values;
   float sum_x_square = 0.0f;
   float sum_x = 0.0f;
   for (int i = 0; i < 100; i++) {
-    const auto& data = GenerateRandomTensor<float>({1, 3, 4}, float_rand);
+    const auto& data = GenerateRandomTensor<float>({1, 3, 4}, random_func);
     for (float value : data) {
       values.insert(value);
       sum_x_square += value * value;
diff --git a/tensorflow/lite/testing/tflite_diff_flags.h b/tensorflow/lite/testing/tflite_diff_flags.h
index 8b1205e58d7..7022cb03ad1 100644
--- a/tensorflow/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/lite/testing/tflite_diff_flags.h
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <cstring>
 
+#include "absl/strings/match.h"
 #include "tensorflow/core/util/command_line_flags.h"
 #include "tensorflow/lite/testing/split.h"
 #include "tensorflow/lite/testing/tflite_diff_util.h"
@@ -76,11 +77,11 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
 
   TfLiteDriver::DelegateType delegate = TfLiteDriver::DelegateType::kNone;
   if (!values.delegate_name.empty()) {
-    if (delegate_name == "NNAPI") {
+    if (absl::EqualsIgnoreCase(values.delegate_name, "nnapi")) {
       delegate = TfLiteDriver::DelegateType::kNnapi;
-    } else if (values.delegate_name == "GPU") {
+    } else if (absl::EqualsIgnoreCase(values.delegate_name, "gpu")) {
       delegate = TfLiteDriver::DelegateType::kGpu;
-    } else if (values.delegate_name == "FLEX") {
+    } else if (absl::EqualsIgnoreCase(values.delegate_name, "flex")) {
       delegate = TfLiteDriver::DelegateType::kFlex;
     } else {
       fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
diff --git a/tensorflow/lite/testing/tflite_model_test.bzl b/tensorflow/lite/testing/tflite_model_test.bzl
new file mode 100644
index 00000000000..a68d5e649b1
--- /dev/null
+++ b/tensorflow/lite/testing/tflite_model_test.bzl
@@ -0,0 +1,152 @@
+"""Definition for tflite_model_test rule that runs a TF Lite model accuracy test.
+
+This rule generates targets to run a diff-based model accuracy test against
+synthetic, random inputs. Future work will allow injection of "golden" inputs,
+as well as more robust execution on mobile devices.
+
+Example usage:
+
+tflite_model_test(
+    name = "simple_diff_test",
+    tensorflow_model_file = "//tensorflow/lite:testdata/multi_add.pb",
+    input_layer = "a,b,c,d",
+    input_layer_shape = "1,8,8,3:1,8,8,3:1,8,8,3:1,8,8,3",
+    input_layer_type = "float,float,float,float",
+    output_layer = "x,y",
+)
+"""
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+def tflite_model_test(
+        name,
+        tensorflow_model_file,
+        input_layer,
+        input_layer_type,
+        input_layer_shape,
+        output_layer,
+        inference_type = "float",
+        extra_conversion_flags = [],
+        num_runs = 20,
+        tags = [],
+        size = "large"):
+    """Create test targets for validating TFLite model execution relative to TF.
+
+    Args:
+      name: Generated test target name. Note that multiple targets may be
+          created if `delegates` are provided.
+      tensorflow_model_file: The binary GraphDef proto to run the benchmark on.
+      input_layer: A list of input tensors to use in the test.
+      input_layer_shape: The shape of the input layer in csv format.
+      input_layer_type: The data type of the input layer(s) (int, float, etc).
+      output_layer: The layer that output should be read from.
+      inference_type: The data type for inference and output.
+      extra_conversion_flags: Extra flags to append to those used for converting
+          models to the tflite format.
+      num_runs: Number of synthetic test cases to run.
+      tags: Extra tags to apply to the test targets.
+      size: The test size to use.
+    """
+
+    conversion_flags = [
+        "--input_shapes=%s" % input_layer_shape,
+        "--input_arrays=%s" % input_layer,
+        "--output_arrays=%s" % output_layer,
+    ] + extra_conversion_flags
+
+    tflite_model_file = make_tflite_files(
+        target_name = "tflite_" + name + "_model",
+        model_file = tensorflow_model_file,
+        conversion_flags = conversion_flags,
+        inference_type = inference_type,
+    )
+
+    diff_args = [
+        # TODO(b/134772701): Find a better way to extract the absolute path from
+        # a target without relying on $(location), which doesn't work with some
+        # mobile test variants. For now we use $(location), but something like
+        # the following is what we want for mobile tests:
+        # "--tensorflow_model=%s" % tensorflow_model_file.replace("//", "").replace(":", "/"),
+        # "--tflite_model=%s" % tflite_model_file.replace("//", "").replace(":", "/"),
+        "--tensorflow_model=$(location %s)" % tensorflow_model_file,
+        "--tflite_model=$(location %s)" % tflite_model_file,
+        "--input_layer=%s" % input_layer,
+        "--input_layer_type=%s" % input_layer_type,
+        "--input_layer_shape=%s" % input_layer_shape,
+        "--output_layer=%s" % output_layer,
+        "--num_runs_per_pass=%s" % num_runs,
+    ]
+
+    tf_cc_test(
+        name = name,
+        size = size,
+        srcs = ["//tensorflow/lite/testing:tflite_diff_example_test.cc"],
+        args = diff_args,
+        data = [
+            tensorflow_model_file,
+            tflite_model_file,
+        ],
+        tags = tags,
+        deps = [
+            "//tensorflow/lite/testing:init_tensorflow",
+            "//tensorflow/lite/testing:tflite_diff_flags",
+            "//tensorflow/lite/testing:tflite_diff_util",
+        ],
+    )
+
+def make_tflite_files(
+        target_name,
+        model_file,
+        conversion_flags,
+        inference_type):
+    """Uses TFLite to convert and input proto to tflite flatbuffer format.
+
+    Args:
+      target_name: Generated target name.
+      model_file: the path to the input file.
+      conversion_flags: parameters to pass to tflite for conversion.
+      inference_type: The data type for inference and output.
+    Returns:
+      The name of the generated file.
+    """
+    flags = [] + conversion_flags
+    if inference_type == "float":
+        flags += [
+            "--inference_type=FLOAT",
+            "--inference_input_type=FLOAT",
+        ]
+    elif inference_type == "quantized":
+        flags += [
+            "--inference_type=QUANTIZED_UINT8",
+            "--inference_input_type=QUANTIZED_UINT8",
+        ]
+    else:
+        fail("Invalid inference type (%s). Expected 'float' or 'quantized'" % inference_type)
+
+    srcs = [model_file]
+
+    # Convert from Tensorflow graphdef to tflite model.
+    output_file = target_name + ".fb"
+
+    tool = "//tensorflow/lite/python:tflite_convert"
+    cmd = ("$(location %s) " +
+           " --graph_def_file=$(location %s)" +
+           " --output_file=$(location %s)" +
+           " --input_format=TENSORFLOW_GRAPHDEF" +
+           " --output_format=TFLITE " +
+           " ".join(flags)
+               .replace("std_value", "std_dev_value")
+               .replace("quantize_weights=true", "quantize_weights"))
+
+    native.genrule(
+        name = target_name,
+        srcs = srcs,
+        tags = ["manual"],
+        outs = [
+            output_file,
+        ],
+        cmd = cmd % (tool, model_file, output_file),
+        tools = [tool],
+        visibility = ["//tensorflow/lite/testing:__subpackages__"],
+    )
+    return "//%s:%s" % (native.package_name(), output_file)