diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 44a4af182cc..3a0da779d3e 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import function as function_lib
 from tensorflow.python.eager import lift_to_graph
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -695,15 +696,28 @@ class Function(object):
   def _initialize_uninitialized_variables(self, initializers):
     """Make and call a `ConcreteFunction` which initializes variables."""
 
+    if not initializers:
+      return
+
     # Note: using defun here avoids an infinite recursion.
     @function_lib.defun
     def initialize_variables():
       op_map = object_identity.ObjectIdentityDictionary()
-      for v, init in initializers:
+      # Stack all the var_is_initialized values into one tensor and fetch
+      # the numpy value once. This reduces the number of RPCs between client
+      # and worker in the remote case.
+      with ops.init_scope():
+        var_is_initialized = []
+        for v, _ in initializers:
+          var_is_initialized.append(
+              resource_variable_ops.var_is_initialized_op(v.handle))
+        var_is_initialized = array_ops.stack(var_is_initialized).numpy()
+
+      for (v, init), is_initialized in zip(initializers, var_is_initialized):
         with ops.init_scope():
-          if resource_variable_ops.var_is_initialized_op(v.handle):
-            # Ignore variables which are already initialized at trace time.
+          if is_initialized:
             continue
+
         op_map = lift_to_graph.lift_to_graph(
             [init], ops.get_default_graph(), op_map=op_map)
         v.assign(op_map[init], read_value=False)
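
For context, the check this hunk replaces ran var_is_initialized_op and fetched
its result once per variable, so tracing a function that creates N remote
variables cost N client/worker round trips; the new code builds all the
predicates first and fetches a single stacked tensor. A minimal standalone
sketch of the same batching pattern, using the internal ops the patch already
imports (the helper name `uninitialized_subset` is illustrative, not part of
the change):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import resource_variable_ops


    def uninitialized_subset(candidate_vars):
      """Returns the subset of `candidate_vars` that is not yet initialized."""
      with ops.init_scope():
        predicates = [
            resource_variable_ops.var_is_initialized_op(v.handle)
            for v in candidate_vars
        ]
        # One fetch for all variables; in the remote case this is a single RPC
        # instead of one RPC per variable.
        is_initialized = array_ops.stack(predicates).numpy()
      return [v for v, flag in zip(candidate_vars, is_initialized) if not flag]
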
diff --git a/tensorflow/python/eager/remote_benchmarks_test.py b/tensorflow/python/eager/remote_benchmarks_test.py
index 458a8b60667..f628557bc1e 100644
--- a/tensorflow/python/eager/remote_benchmarks_test.py
+++ b/tensorflow/python/eager/remote_benchmarks_test.py
@@ -59,6 +59,23 @@ def run_benchmark(func, num_iters, execution_mode=None):
     return end - start
 
 
+class Foo(object):
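+  """Layer that lazily creates `num_vars` variables on its first call."""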
+
+  def __init__(self, num_vars):
+    self._num_vars = num_vars
+    self._v = []
+
+  def __call__(self, inputs):
+    if not self._v:
+      for _ in range(self._num_vars):
+        self._v.append(variables.Variable(
+            random_ops.random_uniform([]), shape=[]))
+    for v in self._v:
+      inputs = inputs * v
+    return inputs
+
+
 class RemoteWorkerMicroBenchmarks(test.Benchmark):
 
   def __init__(self):
@@ -160,6 +176,26 @@ class RemoteWorkerMicroBenchmarks(test.Benchmark):
     # executed when their corresponding device and manager are still available.
     gc.collect()
 
+  def benchmark_create_vars_inside_function(self):
+    remote.connect_to_remote_host(self._cached_server_target1)
+
+    def func():
+      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
+        layer = Foo(50)
+
+        @def_function.function
+        def remote_func():
+          with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
+            return layer(random_ops.random_uniform([]))
+
+        return remote_func()
+
+    self._run(func, execution_mode=context.ASYNC, num_iters=100)
+    # NOTE(b/136184459): Force garbage collecting hanging resources before
+    # subsequent calls to set_server_def, to ensure the destroy resource ops are
+    # executed when their corresponding device and manager are still available.
+    gc.collect()
+
 
 if __name__ == "__main__":
   test.main()
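
From the user's point of view, the scenario the new benchmark measures is a
tf.function that creates its variables lazily on the first trace while placed
on a remote worker, which is when _initialize_uninitialized_variables runs. A
rough public-API equivalent of Foo plus benchmark_create_vars_inside_function
(the target address, remote_step, and the variable count are illustrative
assumptions, not part of the patch):

    import tensorflow as tf

    # Assumption: a worker server is already running and reachable at this
    # target (e.g. started via tf.distribute.Server); replace with the real
    # address before running.
    tf.config.experimental_connect_to_host("grpc://localhost:2222")

    _vars = []


    @tf.function
    def remote_step(x):
      # Variables are created on the first trace only, mirroring Foo.__call__;
      # initializing them exercises the batched is-initialized check above.
      if not _vars:
        for _ in range(50):
          _vars.append(tf.Variable(tf.random.uniform([])))
      for v in _vars:
        x = x * v
      return x


    with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
      print(remote_step(tf.random.uniform([])))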