TF Eager: Avoid creating some unnecessary zeros during backprop.
PiperOrigin-RevId: 169195496
parent dba4bf1e7b
commit f08ec5722b
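In outline: the tape now records each operation's type, and backprop consults a per-op table of gradient slots whose values may be left as None instead of being filled with freshly built zero tensors. Below is a minimal pure-Python sketch of that filtering idea; the table name, helper, and zeros() stand-in are illustrative only, not the TensorFlow implementation.

# Sketch only: op types that tolerate None gradients at the listed indices.
GRAD_FN_ACCEPTS_NONE = {
    "FusedBatchNorm": [1, 2, 3, 4],        # only the gradient at index 0 is used
    "SoftmaxCrossEntropyWithLogits": [1],  # index 1 may stay None
}

def fill_missing_grads(op_type, out_grads, output_shapes, zeros):
  # Replace None gradients with zeros, except where None is acceptable.
  none_ok = GRAD_FN_ACCEPTS_NONE.get(op_type, [])
  return [
      g if (g is not None or i in none_ok) else zeros(output_shapes[i])
      for i, g in enumerate(out_grads)
  ]

# A hypothetical 5-output op: only index 0 gets a zeros value built for it.
grads = fill_missing_grads("FusedBatchNorm", [None] * 5, [(2, 3)] * 5,
                           zeros=lambda shape: [[0.0] * shape[1]] * shape[0])
print(grads[0])   # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(grads[1:])  # [None, None, None, None]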
@@ -237,7 +237,11 @@ def imperative_grad(
     for i in range(len(out_gradients)):
       if out_gradients[i] is None:
         # TODO(apassos) this should be in the right device
-        out_gradients[i] = array_ops.zeros(*op_trace.output_shape_and_dtype[i])
+        none_indices = _grad_fn_accepts_none_for_indices.get(
+            op_trace.op_type, None)
+        if none_indices is None or i not in none_indices:
+          out_gradients[i] = array_ops.zeros(
+              *op_trace.output_shape_and_dtype[i])
       else:
         out_gradients[i] = _aggregate_grads(out_gradients[i])
 
@@ -335,8 +339,11 @@ def _magic_gradient_function(op_name, attr_tuple, num_inputs,
   grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
   if grad_fn is None:
     return [None] * num_inputs
+
+  none_indices = _grad_fn_accepts_none_for_indices.get(op_name, [])
   out_grads = [
-      o if (o is not None) else array_ops.zeros_like(outputs[i])
+      o if (o is not None or i in none_indices)
+      else array_ops.zeros_like(outputs[i])
       for i, o in enumerate(out_grads)
   ]
   return grad_fn(mock_op, *out_grads)
@@ -452,6 +459,24 @@ _ops_which_dont_need_inputs = set([
 ])
 
 
+# TODO(agarwal): use an automatic mechanism for handling None arguments to
+# gradient functions.
+# Some gradient functions can accept None arguments for gradients. The following
+# maps the operation name to the indices at which the corresponding gradient
+# function can accept None values.
+# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
+# during backprop. However the gradient function uses only the first of those
+# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
+# indicates that only the gradient corresponding to index 0 is used, and the
+# gradient values at indices 1-4 are ignored (and hence can be None). The
+# backprop algorithm can then leverage this by not constructing zeros to
+# pass for those indices.
+_grad_fn_accepts_none_for_indices = {
+    "SoftmaxCrossEntropyWithLogits": [1],
+    "FusedBatchNorm": [1, 2, 3, 4]
+}
+
+
 def _record_gradient(op_name, inputs, attrs, results, name):
   """Records gradients for a TensorFlow operation.
 
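As a quick check on the comment above: the imperative_grad hunk uses a None default for the lookup, so an op that is not listed in the table still gets a zeros tensor for every missing gradient slot. A small sketch of just that condition; the helper name and table below are illustrative, not TensorFlow code.

def needs_zeros(op_type, index, accepts_none):
  # True when a missing gradient at `index` must be materialized as zeros.
  none_indices = accepts_none.get(op_type, None)
  return none_indices is None or index not in none_indices

table = {"FusedBatchNorm": [1, 2, 3, 4]}
print(needs_zeros("FusedBatchNorm", 0, table))  # True: gradient 0 is actually used
print(needs_zeros("FusedBatchNorm", 3, table))  # False: None is acceptable here
print(needs_zeros("UnknownOp", 3, table))       # True: unlisted ops always get zeros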
@@ -500,7 +525,7 @@ def _record_gradient(op_name, inputs, attrs, results, name):
     return result
 
   inputs = [ops.convert_to_tensor(x) for x in inputs]
-  tape.record_operation(results, inputs, [], grad_fn)
+  tape.record_operation(op_name, results, inputs, [], grad_fn)
   if _tracing:
     print("Computed op", (name if name else op_name), "inputs", inputs,
           "outputs", results)
@@ -55,6 +55,7 @@ def custom_gradient(f):
 
     flat_result = nest.flatten(result)
     tape.record_operation(
+        f.__name__,
         flat_result,
         input_tensors,
         [],
@@ -86,7 +86,8 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
     tensor_map[ops.tensor_id(value)] = (value, captured_value)
   else:
     captured_value = captured_value[1]
-  tape.record_operation([captured_value], [value], [], lambda x: x)
+  tape.record_operation("captured_value", [captured_value], [value], [],
+                        lambda x: x)
   return captured_value
 
 
@@ -262,6 +263,7 @@ class _GraphModeFunction(object):
       side_outputs = outputs[len(self._returns):]
 
       tape.record_operation(
+          signature.name,
           real_outputs,
           (args + self._extra_inputs),
           side_outputs,
@@ -31,6 +31,7 @@ def tid(tensor):
 
 class TapeEntry(
     collections.namedtuple("TapeEntry", [
+        "op_type",
         "output_ids", "input_ids", "side_outputs", "backward_function",
         "output_shape_and_dtype",
     ])):
@@ -96,8 +97,8 @@ class Tape(object):
     self._watched_variables.add(v)
     self.watch(v.handle)
 
-  def record_operation(self, output_tensors, input_tensors, side_outputs,
-                       backward_function):
+  def record_operation(self, op_type, output_tensors, input_tensors,
+                       side_outputs, backward_function):
     """Records an operation in the tape."""
     if not self.should_record(input_tensors):
       return output_tensors
@@ -109,6 +110,7 @@ class Tape(object):
       i = tid(t)
       self._tensor_usage[i] = self._tensor_usage.get(i, 0) + 1
     self._op_tape[self._next_op_id] = TapeEntry(
+        op_type,
         [tid(t) for t in output_tensors],
         [tid(t) for t in input_tensors],
         side_outputs,
@@ -225,11 +227,11 @@ def should_record(tensors):
   return any(x.should_record(tensors) for x in _tape_stack.stack)
 
 
-def record_operation(output_tensors, input_tensors, side_outputs,
+def record_operation(op_type, output_tensors, input_tensors, side_outputs,
                      backward_function):
   """Records the operation on all tapes in the stack."""
   for t in _tape_stack.stack:
-    t.record_operation(output_tensors,
+    t.record_operation(op_type, output_tensors,
                        input_tensors,
                        side_outputs,
                        backward_function)
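The tape hunks above thread an op_type string into every recorded entry so the backward pass can key per-op lookups (such as the None-tolerance table) by operation type. A toy tape illustrating the shape of that record; this is a hypothetical class, not TensorFlow's Tape.

import collections

# Hypothetical, stripped-down entry mirroring the new "op_type" field.
Entry = collections.namedtuple("Entry", ["op_type", "output_ids", "input_ids"])

class ToyTape(object):
  def __init__(self):
    self.entries = []

  def record_operation(self, op_type, output_tensors, input_tensors):
    # Keeping op_type alongside the tensor ids lets a later backward pass
    # look up per-op metadata, e.g. which gradient slots tolerate None.
    self.entries.append(Entry(op_type,
                              [id(t) for t in output_tensors],
                              [id(t) for t in input_tensors]))

tape = ToyTape()
tape.record_operation("FusedBatchNorm", [object()] * 5, [object()] * 3)
print(tape.entries[0].op_type)  # FusedBatchNorm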
@@ -769,7 +769,7 @@ class EagerTensor(Tensor):
       grad_h = c_api.TFE_TensorHandleCopyToDevice(
           dresult._handle, ctx._handle, self_device, status)
       return _tensor_from_handle(grad_h)
-    tape.record_operation([new_tensor], [self], [], grad_fun)
+    tape.record_operation("_copy", [new_tensor], [self], [], grad_fun)
     return new_tensor
   # pylint: enable=protected-access
 
@@ -429,7 +429,7 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
     const_fill_value = tensor_util.constant_value(g)
     return const_fill_value is not None and (const_fill_value == 0).all()
 
-  if not IsZero(grad_grad):
+  if grad_grad is not None and not IsZero(grad_grad):
     logits = op.inputs[0]
     softmax = nn_ops.softmax(logits)
 