Refactor the keras optimizer gradient apply function to reuse a context call and skip name scopes in eager mode where they have no effect. This reduces the Python overhead of applying gradient updates in eager mode.
PiperOrigin-RevId: 283867294 Change-Id: I8d61428b79d377c3f0ff724a56aaffdb795865ba
This commit is contained in:
parent
ca86f6cfe9
commit
30b99dd053
@ -474,18 +474,17 @@ class OptimizerV2(trackable.Trackable):
|
||||
else:
|
||||
return update_op
|
||||
|
||||
eagerly_outside_functions = ops.executing_eagerly_outside_functions()
|
||||
update_ops = []
|
||||
with backend.name_scope(name or self._name):
|
||||
with ops.name_scope(name or self._name, skip_on_eager=True):
|
||||
for grad, var in grads_and_vars:
|
||||
scope_name = ("update" if ops.executing_eagerly_outside_functions() else
|
||||
"update_" + var.op.name)
|
||||
# Colocate the update with variables to avoid unnecessary communication
|
||||
# delays. See b/136304694.
|
||||
with backend.name_scope(
|
||||
scope_name), distribution.extended.colocate_vars_with(var):
|
||||
update_ops.extend(
|
||||
distribution.extended.update(
|
||||
var, apply_grad_to_update_var, args=(grad,), group=False))
|
||||
with distribution.extended.colocate_vars_with(var):
|
||||
with ops.name_scope("update" if eagerly_outside_functions else
|
||||
"update_" + var.op.name, skip_on_eager=True):
|
||||
update_ops.extend(distribution.extended.update(
|
||||
var, apply_grad_to_update_var, args=(grad,), group=False))
|
||||
|
||||
any_symbolic = any(isinstance(i, ops.Operation) or
|
||||
tf_utils.is_symbolic_tensor(i) for i in update_ops)
|
||||
|
Loading…
Reference in New Issue
Block a user