diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 54fc124cde6..0e40095f255 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -234,7 +234,14 @@ cc_library(
         ],
         "//conditions:default": [],
     }),
-    deps = [
+    deps = select({
+        "//tensorflow/lite/delegates/gpu/cl:opencl_delegate_no_gl": [],
+        "//conditions:default": [
+            "//tensorflow/lite/delegates/gpu/gl:api2",
+        ],
+    }) + [
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/types:span",
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite:minimal_logging",
         "//tensorflow/lite/c:common",
@@ -247,9 +254,6 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:quantization_util",
         "//tensorflow/lite/delegates/gpu/common:status",
-        "//tensorflow/lite/delegates/gpu/gl:api2",
         "//tensorflow/lite/kernels/internal:optimized_base",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/types:span",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/api.h b/tensorflow/lite/delegates/gpu/api.h
index 1dfeeebd700..7892d0ce2f6 100644
--- a/tensorflow/lite/delegates/gpu/api.h
+++ b/tensorflow/lite/delegates/gpu/api.h
@@ -43,9 +43,14 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
-#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"
 #include <vulkan/vulkan.h>
 
+#define GL_NO_PROTOTYPES
+#define EGL_NO_PROTOTYPES
+#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"
+#undef GL_NO_PROTOTYPES
+#undef EGL_NO_PROTOTYPES
+
 namespace tflite {
 namespace gpu {
 
diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD
index ffb9d6204ad..9155bc1166a 100644
--- a/tensorflow/lite/delegates/gpu/cl/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/BUILD
@@ -9,23 +9,34 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+config_setting(
+    name = "opencl_delegate_no_gl",
+    values = {"copt": "-DCL_DELEGATE_NO_GL"},
+)
+
 cc_library(
     name = "api",
     srcs = ["api.cc"],
     hdrs = ["api.h"],
-    deps = [
+    deps = select({
+        ":opencl_delegate_no_gl": [],
+        "//conditions:default": [
+            ":egl_sync",
+            ":gl_interop",
+        ],
+    }) + [
         ":cl_command_queue",
         ":cl_errors",
         ":cl_event",
-        ":egl_sync",
         ":environment",
-        ":gl_interop",
         ":inference_context",
         ":opencl_wrapper",
         ":precision",
         ":tensor",
         ":tensor_type",
         ":tensor_type_util",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/types:span",
         "//tensorflow/lite/delegates/gpu:api",
         "//tensorflow/lite/delegates/gpu/cl/kernels:converter",
         "//tensorflow/lite/delegates/gpu/common:data_type",
@@ -33,8 +44,6 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/types:span",
     ],
 )
 
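
The opencl_delegate_no_gl config_setting defined above matches whenever the
build is invoked with --copt=-DCL_DELEGATE_NO_GL (e.g. bazel build
--copt=-DCL_DELEGATE_NO_GL ...), so a single flag both removes the GL-interop
targets from the dependency graph through the select() clauses and gates the
matching #ifdef blocks in the sources below.
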
diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc
index ffe0fb68881..503b04543b4 100644
--- a/tensorflow/lite/delegates/gpu/cl/api.cc
+++ b/tensorflow/lite/delegates/gpu/cl/api.cc
@@ -15,7 +15,9 @@ limitations under the License.
 
 #include "tensorflow/lite/delegates/gpu/cl/api.h"
 
-#include <EGL/eglext.h>
+#ifndef CL_DELEGATE_NO_GL
+#define CL_DELEGATE_ALLOW_GL
+#endif
 
 #include <algorithm>
 #include <cstring>
@@ -25,9 +27,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_errors.h"
 #include "tensorflow/lite/delegates/gpu/cl/cl_event.h"
-#include "tensorflow/lite/delegates/gpu/cl/egl_sync.h"
 #include "tensorflow/lite/delegates/gpu/cl/environment.h"
-#include "tensorflow/lite/delegates/gpu/cl/gl_interop.h"
 #include "tensorflow/lite/delegates/gpu/cl/inference_context.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/converter.h"
 #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
@@ -39,6 +39,13 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
+#ifdef CL_DELEGATE_ALLOW_GL
+#include <EGL/eglext.h>
+
+#include "tensorflow/lite/delegates/gpu/cl/egl_sync.h"
+#include "tensorflow/lite/delegates/gpu/cl/gl_interop.h"
+#endif
+
 namespace tflite {
 namespace gpu {
 namespace cl {
@@ -87,11 +94,13 @@ class DefaultTensorTie : public TensorTie {
       const TensorTieDef& def,
       const TensorObjectConverterBuilder& converter_builder) {
     auto object_type = def.external_def.object_def.object_type;
+#ifdef CL_DELEGATE_ALLOW_GL
     if (def.external_def.object_def.user_provided &&
         GlClBufferCopier::IsSupported(def.external_def.object_def,
                                       def.internal_def.object_def)) {
       return true;
     }
+#endif
     return (object_type == ObjectType::OPENCL_BUFFER ||
             object_type == ObjectType::OPENCL_TEXTURE ||
             object_type == ObjectType::CPU_MEMORY) &&
@@ -138,6 +147,7 @@ class DefaultTensorTie : public TensorTie {
  private:
   absl::Status Init(TensorObjectConverterBuilder* converter_builder,
                     Environment* env) {
+#ifdef CL_DELEGATE_ALLOW_GL
     if (def().external_def.object_def.user_provided &&
         GlClBufferCopier::IsSupported(def().external_def.object_def,
                                       def().internal_def.object_def)) {
@@ -156,6 +166,12 @@ class DefaultTensorTie : public TensorTie {
       RETURN_IF_ERROR(converter_builder->MakeConverter(
           def().internal_def, def().external_def, &converter_to_));
     }
+#else
+    RETURN_IF_ERROR(converter_builder->MakeConverter(
+        def().external_def, def().internal_def, &converter_from_));
+    RETURN_IF_ERROR(converter_builder->MakeConverter(
+        def().internal_def, def().external_def, &converter_to_));
+#endif
     return MaybeAllocateExternalObject(env);
   }
 
@@ -275,6 +291,7 @@ class TwoStepTensorTie : public TensorTie {
   std::unique_ptr<TensorTie> outer_tie_;
 };
 
+#ifdef CL_DELEGATE_ALLOW_GL
 // Captures GL object into CL context before performing a conversion.
 class GlBufferHolder : public TensorTie {
  public:
@@ -351,6 +368,7 @@ class GlBufferHolder : public TensorTie {
   std::unique_ptr<TensorTie> tie_;
   TensorObject external_obj_;
 };
+#endif
 
 TensorObject TensorToObj(const Tensor& tensor) {
   if (tensor.GetStorageType() == TensorStorageType::BUFFER) {
@@ -365,19 +383,28 @@ TensorObject TensorToObj(const Tensor& tensor) {
 // Responsible for creating new tensor objects.
 class TensorTieFactory {
  public:
-  TensorTieFactory(Environment* env, InferenceContext* context,
-                   GlInteropFabric* gl_interop_fabric)
+  TensorTieFactory(Environment* env, InferenceContext* context
+#ifdef CL_DELEGATE_ALLOW_GL
+                   ,
+                   GlInteropFabric* gl_interop_fabric
+#endif
+                   )
       : env_(*env),
         context_(*context),
+#ifdef CL_DELEGATE_ALLOW_GL
         gl_interop_fabric_(gl_interop_fabric),
-        converter_builder_(NewConverterBuilder(env)) {}
+#endif
+        converter_builder_(NewConverterBuilder(env)) {
+  }
 
   bool IsSupported(const TensorTieDef& def) const {
     return IsValid(def.external_def.object_def) &&
            (NoopTensorTie::IsSupported(def) ||
             DefaultTensorTie::IsSupported(def, *converter_builder_) ||
+#ifdef CL_DELEGATE_ALLOW_GL
             (gl_interop_fabric_ &&
              GlBufferHolder::IsSupported(def, *converter_builder_)) ||
+#endif
             TwoStepTensorTie::IsSupported(def, *converter_builder_));
   }
 
@@ -392,10 +419,12 @@ class TensorTieFactory {
     if (DefaultTensorTie::IsSupported(def, *converter)) {
       return DefaultTensorTie::New(def, internal_object, converter, &env_, tie);
     }
+#ifdef CL_DELEGATE_ALLOW_GL
     if (gl_interop_fabric_ && GlBufferHolder::IsSupported(def, *converter)) {
       return GlBufferHolder::New(def, internal_object, converter,
                                  gl_interop_fabric_, &env_, tie);
     }
+#endif
     if (TwoStepTensorTie::IsSupported(def, *converter)) {
       return TwoStepTensorTie::New(def, internal_object, converter, &env_, tie);
     }
@@ -405,18 +434,29 @@ class TensorTieFactory {
  private:
   Environment& env_;
   InferenceContext& context_;
+#ifdef CL_DELEGATE_ALLOW_GL
   GlInteropFabric* gl_interop_fabric_;
+#endif
   std::unique_ptr<TensorObjectConverterBuilder> converter_builder_;
 };
 
 class InferenceRunnerImpl : public InferenceRunner {
  public:
   InferenceRunnerImpl(Environment* environment,
-                      std::unique_ptr<InferenceContext> context,
-                      std::unique_ptr<GlInteropFabric> gl_interop_fabric)
+                      std::unique_ptr<InferenceContext> context
+#ifdef CL_DELEGATE_ALLOW_GL
+                      ,
+                      std::unique_ptr<GlInteropFabric> gl_interop_fabric
+#endif
+                      )
       : queue_(environment->queue()),
-        context_(std::move(context)),
-        gl_interop_fabric_(std::move(gl_interop_fabric)) {}
+        context_(std::move(context))
+#ifdef CL_DELEGATE_ALLOW_GL
+        ,
+        gl_interop_fabric_(std::move(gl_interop_fabric))
+#endif
+  {
+  }
 
   absl::Status Initialize(const std::vector<TensorTieDef>& inputs,
                           const std::vector<TensorTieDef>& outputs,
@@ -464,9 +504,11 @@ class InferenceRunnerImpl : public InferenceRunner {
   }
 
   absl::Status Run() override {
+#ifdef CL_DELEGATE_ALLOW_GL
     if (gl_interop_fabric_) {
       RETURN_IF_ERROR(gl_interop_fabric_->Start());
     }
+#endif
     for (auto& obj : inputs_) {
       RETURN_IF_ERROR(obj->CopyFromExternalObject());
     }
@@ -475,9 +517,11 @@ class InferenceRunnerImpl : public InferenceRunner {
     for (auto& obj : outputs_) {
       RETURN_IF_ERROR(obj->CopyToExternalObject());
     }
+#ifdef CL_DELEGATE_ALLOW_GL
     if (gl_interop_fabric_) {
       RETURN_IF_ERROR(gl_interop_fabric_->Finish());
     }
+#endif
     return absl::OkStatus();
   }
 
@@ -506,7 +550,9 @@ class InferenceRunnerImpl : public InferenceRunner {
 
   CLCommandQueue* queue_;
   std::unique_ptr<InferenceContext> context_;
+#ifdef CL_DELEGATE_ALLOW_GL
   std::unique_ptr<GlInteropFabric> gl_interop_fabric_;
+#endif
   std::vector<std::unique_ptr<TensorTie>> inputs_;
   std::vector<std::unique_ptr<TensorTie>> outputs_;
 };
@@ -542,6 +588,7 @@ class InferenceBuilderImpl : public InferenceBuilder {
     }
     RETURN_IF_ERROR(context_->InitFromGraph(create_info, graph, environment_));
 
+#ifdef CL_DELEGATE_ALLOW_GL
     if (env_options.IsGlAware() &&
         IsGlSharingSupported(environment_->device())) {
       gl_interop_fabric_ = absl::make_unique<GlInteropFabric>(
@@ -549,6 +596,10 @@ class InferenceBuilderImpl : public InferenceBuilder {
     }
     tie_factory_ = absl::make_unique<TensorTieFactory>(
         environment_, context_.get(), gl_interop_fabric_.get());
+#else
+    tie_factory_ =
+        absl::make_unique<TensorTieFactory>(environment_, context_.get());
+#endif
 
     inputs_ = LinkTensors(graph, graph.inputs());
     outputs_ = LinkTensors(graph, graph.outputs());
@@ -599,6 +650,7 @@ class InferenceBuilderImpl : public InferenceBuilder {
   }
 
   absl::Status Build(std::unique_ptr<InferenceRunner>* runner) override {
+#ifdef CL_DELEGATE_ALLOW_GL
     if (gl_interop_fabric_ && !HasGlObjects()) {
       // destroy interop layer when there are no GL objects to avoid
       // extra synchronization cost.
@@ -606,6 +658,10 @@ class InferenceBuilderImpl : public InferenceBuilder {
     }
     auto runner_impl = absl::make_unique<InferenceRunnerImpl>(
         environment_, std::move(context_), std::move(gl_interop_fabric_));
+#else
+    auto runner_impl = absl::make_unique<InferenceRunnerImpl>(
+        environment_, std::move(context_));
+#endif
     RETURN_IF_ERROR(
         runner_impl->Initialize(inputs_, outputs_, tie_factory_.get()));
     *runner = std::move(runner_impl);
@@ -676,6 +732,7 @@ class InferenceBuilderImpl : public InferenceBuilder {
   }
 
   bool HasGlObjects() const {
+#ifdef CL_DELEGATE_ALLOW_GL
     auto is_gl = [](ObjectType t) {
       return t == ObjectType::OPENGL_SSBO || t == ObjectType::OPENGL_TEXTURE;
     };
@@ -689,6 +746,7 @@ class InferenceBuilderImpl : public InferenceBuilder {
         return true;
       }
     }
+#endif
     return false;
   }
 
@@ -703,7 +761,9 @@ class InferenceBuilderImpl : public InferenceBuilder {
   }
 
   std::unique_ptr<InferenceContext> context_;
+#ifdef CL_DELEGATE_ALLOW_GL
   std::unique_ptr<GlInteropFabric> gl_interop_fabric_;
+#endif
   Environment* environment_;
 
   std::vector<TensorTieDef> inputs_;
@@ -730,20 +790,25 @@ class InferenceEnvironmentImpl : public InferenceEnvironment {
       RETURN_IF_ERROR(CreateDefaultGPUDevice(&device));
     }
 
+#ifdef CL_DELEGATE_ALLOW_GL
     properties_.is_gl_sharing_supported = IsGlSharingSupported(device);
     properties_.is_gl_to_cl_fast_sync_supported =
         IsClEventFromEglSyncSupported(device);
     properties_.is_cl_to_gl_fast_sync_supported =
         IsEglSyncFromClEventSupported();
+#endif
 
     CLContext context;
     if (options_.context) {
+#ifdef CL_DELEGATE_ALLOW_GL
       if (options_.IsGlAware()) {
         return absl::InvalidArgumentError(
             "OpenCL context and EGL parameters are set in the same time.");
       }
+#endif
       context = CLContext(options_.context, /* has_ownership = */ false);
     } else {
+#ifdef CL_DELEGATE_ALLOW_GL
       if (options_.IsGlAware() && properties_.is_gl_sharing_supported) {
         RETURN_IF_ERROR(CreateCLGLContext(
             device,
@@ -753,6 +818,9 @@ class InferenceEnvironmentImpl : public InferenceEnvironment {
       } else {
         RETURN_IF_ERROR(CreateCLContext(device, &context));
       }
+#else
+      RETURN_IF_ERROR(CreateCLContext(device, &context));
+#endif
     }
 
     CLCommandQueue queue;
diff --git a/tensorflow/lite/delegates/gpu/cl/api.h b/tensorflow/lite/delegates/gpu/cl/api.h
index bddf7de3363..826d4f2bc78 100644
--- a/tensorflow/lite/delegates/gpu/cl/api.h
+++ b/tensorflow/lite/delegates/gpu/cl/api.h
@@ -16,6 +16,10 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_API_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_API_H_
 
+#ifdef CL_DELEGATE_NO_GL
+#define EGL_NO_PROTOTYPES
+#endif
+
 #include <EGL/egl.h>
 
 #include <cstdint>
diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.h b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.h
index 1a9fb73e6ab..e10489cc99b 100644
--- a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.h
+++ b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.h
@@ -16,8 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_API_DELEGATE_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_API_DELEGATE_H_
 
+#define GL_NO_PROTOTYPES
+#define EGL_NO_PROTOTYPES
 #include <EGL/egl.h>
 #include <GLES3/gl31.h>
+#undef GL_NO_PROTOTYPES
+#undef EGL_NO_PROTOTYPES
+
 #include <stdint.h>
 
 #include "tensorflow/lite/c/common.h"
@@ -76,8 +81,8 @@ typedef struct {
 // .compile_options = {
 //   .precision_loss_allowed = false,
 // }
-// .egl_display = eglGetCurrentDisplay(),
-// .egl_context = eglGetCurrentContext();
+// .egl_display = EGL_NO_DISPLAY,
+// .egl_context = EGL_NO_CONTEXT;
 TFL_CAPI_EXPORT TfLiteDelegate* TfLiteGpuDelegateCreate_New(
     const TfLiteGpuDelegateOptions_New* options);
 
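
With the comment above now showing the EGL sentinels, a no-GL client creates
the delegate roughly as follows. This is a sketch only: the field names come
from the comment above, and value-initializing the remaining options is an
assumption of the sketch, not something this patch guarantees:

    TfLiteGpuDelegateOptions_New options = {};  // assumed zero-init is valid
    options.compile_options.precision_loss_allowed = false;
    options.egl_display = EGL_NO_DISPLAY;  // no GL interop requested
    options.egl_context = EGL_NO_CONTEXT;
    TfLiteDelegate* delegate = TfLiteGpuDelegateCreate_New(&options);
    // ... interpreter->ModifyGraphWithDelegate(delegate), then clean up ...
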
diff --git a/tensorflow/lite/delegates/gpu/cl/testing/BUILD b/tensorflow/lite/delegates/gpu/cl/testing/BUILD
index 723e4cd9e99..c82190ca0e6 100644
--- a/tensorflow/lite/delegates/gpu/cl/testing/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/testing/BUILD
@@ -16,3 +16,21 @@ cc_binary(
         "@com_google_absl//absl/time",
     ],
 )
+
+cc_binary(
+    name = "delegate_testing",
+    srcs = ["delegate_testing.cc"],
+    tags = [
+        "nobuilder",
+        "notap",
+    ],
+    deps = [
+        "//tensorflow/lite/delegates/gpu:delegate",
+        "//tensorflow/lite/delegates/gpu/cl:gpu_api_delegate",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
+        "//tensorflow/lite/kernels:builtin_ops",
+        "//tensorflow/lite/kernels:kernel_util",
+        "@com_google_absl//absl/time",
+    ],
+)
diff --git a/tensorflow/lite/delegates/gpu/cl/testing/delegate_testing.cc b/tensorflow/lite/delegates/gpu/cl/testing/delegate_testing.cc
new file mode 100644
index 00000000000..4e92f897d96
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/cl/testing/delegate_testing.cc
@@ -0,0 +1,158 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <chrono>  // NOLINT(build/c++11)
+#include <cmath>
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "absl/time/time.h"
+#include "tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"
+#include "tensorflow/lite/delegates/gpu/delegate.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/register.h"
+
+namespace {
+
+void FillInputTensor(tflite::Interpreter* interpreter) {
+  for (int k = 0; k < interpreter->inputs().size(); ++k) {
+    float* p = interpreter->typed_input_tensor<float>(k);
+    const auto n =
+        tflite::NumElements(interpreter->tensor(interpreter->inputs()[k]));
+    for (int i = 0; i < n; ++i) {
+      p[i] = std::sin(i);
+    }
+  }
+}
+
+void CompareCPUGPUResults(tflite::Interpreter* cpu, tflite::Interpreter* gpu,
+                          float eps) {
+  for (int i = 0; i < cpu->outputs().size(); ++i) {
+    const float* cpu_out = cpu->typed_output_tensor<float>(i);
+    const float* gpu_out = gpu->typed_output_tensor<float>(i);
+    auto out_n = tflite::NumElements(cpu->tensor(cpu->outputs()[i]));
+    const int kMaxPrint = 10;
+    int printed = 0;
+    int total_different = 0;
+    for (int k = 0; k < out_n; ++k) {
+      const float abs_diff = std::fabs(cpu_out[k] - gpu_out[k]);
+      if (abs_diff > eps) {
+        total_different++;
+        if (printed < kMaxPrint) {
+          std::cout << "Output #" << i << ": element #" << k << ": CPU value - "
+                    << cpu_out[k] << ", GPU value - " << gpu_out[k]
+                    << ", abs diff - " << abs_diff << std::endl;
+          printed++;
+        }
+        if (printed == kMaxPrint) {
+          std::cout << "Printed " << kMaxPrint
+                    << " different elements, threshhold - " << eps
+                    << ", next different elements skipped" << std::endl;
+          printed++;
+        }
+      }
+    }
+    std::cout << "Total " << total_different
+              << " different elements, for output #" << i << ", threshhold - "
+              << eps << std::endl;
+  }
+}
+
+}  // namespace
+
+int main(int argc, char** argv) {
+  if (argc <= 1) {
+    std::cerr << "Expected model path as second argument." << std::endl;
+    return -1;
+  }
+
+  auto model = tflite::FlatBufferModel::BuildFromFile(argv[1]);
+  if (!model) {
+    std::cerr << "FlatBufferModel::BuildFromFile failed, model path - "
+              << argv[1] << std::endl;
+    return -1;
+  }
+  tflite::ops::builtin::BuiltinOpResolver op_resolver;
+  tflite::InterpreterBuilder builder(*model, op_resolver);
+
+  // CPU.
+  std::unique_ptr<tflite::Interpreter> cpu_inference;
+  builder(&cpu_inference);
+  if (!cpu_inference) {
+    std::cerr << "Failed to build CPU inference." << std::endl;
+    return -1;
+  }
+  auto status = cpu_inference->AllocateTensors();
+  if (status != kTfLiteOk) {
+    std::cerr << "Failed to AllocateTensors for CPU inference." << std::endl;
+    return -1;
+  }
+  FillInputTensor(cpu_inference.get());
+  status = cpu_inference->Invoke();
+  if (status != kTfLiteOk) {
+    std::cerr << "Failed to Invoke CPU inference." << std::endl;
+    return -1;
+  }
+
+  // GPU.
+  std::unique_ptr<tflite::Interpreter> gpu_inference;
+  builder(&gpu_inference);
+  if (!gpu_inference) {
+    std::cerr << "Failed to build GPU inference." << std::endl;
+    return -1;
+  }
+  TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
+  options.is_precision_loss_allowed = -1;
+  options.inference_preference =
+      TFLITE_GPU_INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
+  options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
+  options.inference_priority2 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE;
+  options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
+  auto* gpu_delegate = TfLiteGpuDelegateV2Create(&options);
+  status = gpu_inference->ModifyGraphWithDelegate(gpu_delegate);
+  if (status != kTfLiteOk) {
+    std::cerr << "ModifyGraphWithDelegate failed." << std::endl;
+    return -1;
+  }
+  FillInputTensor(gpu_inference.get());
+  status = gpu_inference->Invoke();
+  if (status != kTfLiteOk) {
+    std::cerr << "Failed to Invoke GPU inference." << std::endl;
+    return -1;
+  }
+
+  CompareCPUGPUResults(cpu_inference.get(), gpu_inference.get(), 1e-4f);
+
+  // CPU inference latency.
+  auto start = std::chrono::high_resolution_clock::now();
+  cpu_inference->Invoke();
+  auto end = std::chrono::high_resolution_clock::now();
+  std::cout << "CPU time - " << (end - start).count() * 1e-6f << "ms"
+            << std::endl;
+
+  // GPU inference latency.
+  start = std::chrono::high_resolution_clock::now();
+  gpu_inference->Invoke();
+  end = std::chrono::high_resolution_clock::now();
+  std::cout << "GPU time(CPU->GPU->CPU) - " << (end - start).count() * 1e-6f
+            << "ms" << std::endl;
+
+  TfLiteGpuDelegateV2Delete(gpu_delegate);
+  return EXIT_SUCCESS;
+}
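
Usage sketch for the new binary (invocation shape only): delegate_testing
<model.tflite>. It runs the model once on the CPU and once through the GPU
delegate, reports per-element output differences above the 1e-4 threshold,
and then prints a single-invocation latency for each path.
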
diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc
index 38e60753c59..0f2d9811633 100644
--- a/tensorflow/lite/delegates/gpu/delegate.cc
+++ b/tensorflow/lite/delegates/gpu/delegate.cc
@@ -34,10 +34,13 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/gl/api2.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/minimal_logging.h"
 
+#ifndef CL_DELEGATE_NO_GL
+#include "tensorflow/lite/delegates/gpu/gl/api2.h"
+#endif
+
 namespace tflite {
 namespace gpu {
 namespace {
@@ -315,6 +318,7 @@ class DelegateKernel {
 
   absl::Status InitializeOpenGlApi(GraphFloat32* graph,
                                    std::unique_ptr<InferenceBuilder>* builder) {
+#ifndef CL_DELEGATE_NO_GL
     gl::InferenceEnvironmentOptions env_options;
     gl::InferenceEnvironmentProperties properties;
     RETURN_IF_ERROR(
@@ -330,13 +334,18 @@ class DelegateKernel {
     enforce_same_thread_ = true;
     TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                          "Initialized OpenGL-based API.");
     return absl::OkStatus();
+#else
+    return absl::UnavailableError("OpenGL-based API disabled");
+#endif
   }
 
   // The Delegate instance that's shared across all DelegateKernel instances.
   Delegate* const delegate_;  // doesn't own the memory.
   std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
+#ifndef CL_DELEGATE_NO_GL
   std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
+#endif
   std::unique_ptr<InferenceRunner> runner_;
   std::vector<int64_t> input_indices_;
   std::vector<int64_t> output_indices_;