Refactor the keras optimizer gradient apply function to reuse a context call and skip name scopes in eager mode where they have no effect. This reduces the Python overhead of applying gradient updates in eager mode.
PiperOrigin-RevId: 283867294 Change-Id: I8d61428b79d377c3f0ff724a56aaffdb795865ba
This commit is contained in:
parent
ca86f6cfe9
commit
30b99dd053
@ -474,17 +474,16 @@ class OptimizerV2(trackable.Trackable):
|
|||||||
else:
|
else:
|
||||||
return update_op
|
return update_op
|
||||||
|
|
||||||
|
eagerly_outside_functions = ops.executing_eagerly_outside_functions()
|
||||||
update_ops = []
|
update_ops = []
|
||||||
with backend.name_scope(name or self._name):
|
with ops.name_scope(name or self._name, skip_on_eager=True):
|
||||||
for grad, var in grads_and_vars:
|
for grad, var in grads_and_vars:
|
||||||
scope_name = ("update" if ops.executing_eagerly_outside_functions() else
|
|
||||||
"update_" + var.op.name)
|
|
||||||
# Colocate the update with variables to avoid unnecessary communication
|
# Colocate the update with variables to avoid unnecessary communication
|
||||||
# delays. See b/136304694.
|
# delays. See b/136304694.
|
||||||
with backend.name_scope(
|
with distribution.extended.colocate_vars_with(var):
|
||||||
scope_name), distribution.extended.colocate_vars_with(var):
|
with ops.name_scope("update" if eagerly_outside_functions else
|
||||||
update_ops.extend(
|
"update_" + var.op.name, skip_on_eager=True):
|
||||||
distribution.extended.update(
|
update_ops.extend(distribution.extended.update(
|
||||||
var, apply_grad_to_update_var, args=(grad,), group=False))
|
var, apply_grad_to_update_var, args=(grad,), group=False))
|
||||||
|
|
||||||
any_symbolic = any(isinstance(i, ops.Operation) or
|
any_symbolic = any(isinstance(i, ops.Operation) or
|
||||||
|
Loading…
Reference in New Issue
Block a user