Improve function tracing time for variables created inside functions. In _initialize_uninitialized_variables, pack all the var_is_initialized ops into one tensor and convert it to numpy in a single call, instead of checking a remote var_is_initialized op for each variable, which introduces many blocking RPCs.

PiperOrigin-RevId: 272529466
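
The pattern the change introduces can be sketched in isolation as follows (illustrative only, not part of the commit; the helper name uninitialized_subset is hypothetical, and the sketch assumes eager execution while reusing the same internal TF ops the diff touches):

    import tensorflow as tf
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import resource_variable_ops


    def uninitialized_subset(variable_list):
      # Build one var_is_initialized op per variable; nothing blocks yet.
      flags = [resource_variable_ops.var_is_initialized_op(v.handle)
               for v in variable_list]
      # Stack the flags and fetch them all with a single numpy() conversion,
      # so a remote setup pays one blocking RPC instead of one per variable.
      flags = array_ops.stack(flags).numpy()
      return [v for v, initialized in zip(variable_list, flags)
              if not initialized]


    # Example: eagerly created variables are already initialized, so this
    # prints an empty list.
    vs = [tf.Variable(float(i)) for i in range(3)]
    print(uninitialized_subset(vs))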
Ruoxin Sang authored 2019-10-02 15:10:56 -07:00; committed by TensorFlower Gardener
parent 7bdc261c65
commit 66f1ac58e4
2 changed files with 53 additions and 3 deletions


@@ -28,6 +28,7 @@ from tensorflow.python.eager import function as function_lib
 from tensorflow.python.eager import lift_to_graph
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -695,15 +696,28 @@ class Function(object):
   def _initialize_uninitialized_variables(self, initializers):
     """Make and call a `ConcreteFunction` which initializes variables."""
+    if not initializers:
+      return
+
     # Note: using defun here avoids an infinite recursion.
     @function_lib.defun
     def initialize_variables():
       op_map = object_identity.ObjectIdentityDictionary()
-      for v, init in initializers:
+      # Stack all the var_is_initialized values into one tensor and interpret
+      # the numpy value. This will reduce the number of RPCs between client and
+      # worker in the remote case.
+      with ops.init_scope():
+        var_is_initialized = []
+        for v, _ in initializers:
+          var_is_initialized.append(
+              resource_variable_ops.var_is_initialized_op(v.handle))
+        var_is_initialized = array_ops.stack(var_is_initialized).numpy()
+
+      for (v, init), is_initialized in zip(initializers, var_is_initialized):
         with ops.init_scope():
-          if resource_variable_ops.var_is_initialized_op(v.handle):
+          if is_initialized:
             # Ignore variables which are already initialized at trace time.
             continue
         op_map = lift_to_graph.lift_to_graph(
             [init], ops.get_default_graph(), op_map=op_map)
         v.assign(op_map[init], read_value=False)
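
Why the old loop was slow, sketched in isolation (illustrative only, not commit code): evaluating an eager tensor in a Python if-statement implicitly converts it to a bool, which blocks until the value is fetched from the variable's device; with a remote worker, that is one RPC per variable.

    import tensorflow as tf
    from tensorflow.python.ops import resource_variable_ops

    v = tf.Variable(1.0)
    flag = resource_variable_ops.var_is_initialized_op(v.handle)
    # The implicit bool(flag) below blocks until the flag's value arrives from
    # the variable's device -- the per-variable round trip the commit removes.
    if flag:
      print("v is initialized")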


@@ -59,6 +59,22 @@ def run_benchmark(func, num_iters, execution_mode=None):
   return end - start
 
 
+class Foo(object):
+
+  def __init__(self, num_vars):
+    self._num_vars = num_vars
+    self._v = []
+
+  def __call__(self, inputs):
+    if not self._v:
+      for _ in range(self._num_vars):
+        self._v.append(variables.Variable(
+            random_ops.random_uniform([]), shape=[]))
+    for v in self._v:
+      inputs = inputs * v
+    return inputs
+
+
 class RemoteWorkerMicroBenchmarks(test.Benchmark):
 
   def __init__(self):
@@ -160,6 +176,26 @@ class RemoteWorkerMicroBenchmarks(test.Benchmark):
     # executed when their corresponding device and manager are still available.
     gc.collect()
 
+  def benchmark_create_vars_inside_function(self):
+    remote.connect_to_remote_host(self._cached_server_target1)
+
+    def func():
+      with ops.device("/job:worker/replica:0/task:0/device:CPU:0"):
+        layer = Foo(50)
+
+        @def_function.function
+        def remote_func():
+          with ops.device("/job:worker/replica:0/task:0/device:CPU:0"):
+            return layer(random_ops.random_uniform([]))
+
+        return remote_func()
+
+    self._run(func, execution_mode=context.ASYNC, num_iters=100)
+
+    # NOTE(b/136184459): Force garbage collecting hanging resources before
+    # subsequent calls to set_server_def, to ensure the destroy resource ops are
+    # executed when their corresponding device and manager are still available.
+    gc.collect()
+
 
 if __name__ == "__main__":
   test.main()
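
The new benchmark measures the path this commit optimizes: a tf.function that creates its variables during the first trace. A local, user-level sketch of that path (illustrative only; the Scale class is hypothetical and mirrors Foo above, and the RPC savings only appear when the variables live on a remote worker):

    import tensorflow as tf


    class Scale(tf.Module):

      def __init__(self):
        self.v = None

      @tf.function
      def __call__(self, x):
        if self.v is None:
          # Created during the first trace; initialized afterwards through
          # Function._initialize_uninitialized_variables.
          self.v = tf.Variable(2.0)
        return self.v * x


    m = Scale()
    print(m(tf.constant(3.0)).numpy())  # 6.0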