diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 44a4af182cc..3a0da779d3e 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import function as function_lib
 from tensorflow.python.eager import lift_to_graph
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -695,15 +696,28 @@ class Function(object):
   def _initialize_uninitialized_variables(self, initializers):
     """Make and call a `ConcreteFunction` which initializes variables."""
 
+    if not initializers:
+      return
+
     # Note: using defun here avoids an infinite recursion.
     @function_lib.defun
     def initialize_variables():
       op_map = object_identity.ObjectIdentityDictionary()
-      for v, init in initializers:
+      # Stack all the var_is_initialized values into one tensor and interpret
+      # the numpy value. This will reduce the number of RPCs between client
+      # and worker in the remote case.
+      with ops.init_scope():
+        var_is_initialized = []
+        for v, _ in initializers:
+          var_is_initialized.append(
+              resource_variable_ops.var_is_initialized_op(v.handle))
+        var_is_initialized = array_ops.stack(var_is_initialized).numpy()
+
+      for (v, init), is_initialized in zip(initializers, var_is_initialized):
         with ops.init_scope():
-          if resource_variable_ops.var_is_initialized_op(v.handle):
-            # Ignore variables which are already initialized at trace time.
+          if is_initialized:
             continue
+
         op_map = lift_to_graph.lift_to_graph(
             [init], ops.get_default_graph(), op_map=op_map)
         v.assign(op_map[init], read_value=False)
diff --git a/tensorflow/python/eager/remote_benchmarks_test.py b/tensorflow/python/eager/remote_benchmarks_test.py
index 458a8b60667..f628557bc1e 100644
--- a/tensorflow/python/eager/remote_benchmarks_test.py
+++ b/tensorflow/python/eager/remote_benchmarks_test.py
@@ -59,6 +59,22 @@ def run_benchmark(func, num_iters, execution_mode=None):
   return end - start
 
 
+class Foo(object):
+
+  def __init__(self, num_vars):
+    self._num_vars = num_vars
+    self._v = []
+
+  def __call__(self, inputs):
+    if not self._v:
+      for _ in range(self._num_vars):
+        self._v.append(variables.Variable(
+            random_ops.random_uniform([]), shape=[]))
+    for v in self._v:
+      inputs = inputs * v
+    return inputs
+
+
 class RemoteWorkerMicroBenchmarks(test.Benchmark):
 
   def __init__(self):
@@ -160,6 +176,26 @@ class RemoteWorkerMicroBenchmarks(test.Benchmark):
     # executed when their corresponding device and manager are still available.
     gc.collect()
 
+  def benchmark_create_vars_inside_function(self):
+    remote.connect_to_remote_host(self._cached_server_target1)
+
+    def func():
+      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
+        layer = Foo(50)
+
+        @def_function.function
+        def remote_func():
+          with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
+            return layer(random_ops.random_uniform([]))
+
+        return remote_func()
+
+    self._run(func, execution_mode=context.ASYNC, num_iters=100)
+    # NOTE(b/136184459): Force garbage collecting hanging resources before
+    # subsequent calls to set_server_def, to ensure the destroy resource ops are
+    # executed when their corresponding device and manager are still available.
+    gc.collect()
+
 
 if __name__ == "__main__":
   test.main()
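The core idea of the def_function.py change, as a standalone sketch: instead of resolving one `var_is_initialized_op` per variable (one client/worker round trip each in the remote case), stack all the checks into a single boolean tensor and fetch it once. The helper name `batched_initialization_check` below is hypothetical, not part of the TensorFlow API; the sketch assumes eager execution and TensorFlow's internal import paths, matching the diff above.

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops


def batched_initialization_check(tf_variables):
  """Returns a numpy bool array with one entry per variable.

  Hypothetical helper, for illustration only: stacking the per-variable
  checks into one tensor means the client resolves a single value, so the
  remote case pays one RPC instead of one per variable.
  """
  with ops.init_scope():
    # Build one var_is_initialized_op per variable, but do not resolve the
    # boolean tensors individually; collect them first.
    flags = [
        resource_variable_ops.var_is_initialized_op(v.handle)
        for v in tf_variables
    ]
    # One stack plus one .numpy() call resolves every flag in a single
    # round trip.
    return array_ops.stack(flags).numpy()

A caller can then zip the returned flags with the variables and skip the initializer for every variable whose flag is True, which is the shape of the rewritten loop in `_initialize_uninitialized_variables` above.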