TF Eager: Avoid creating some unnecessary zeros during backprop.

PiperOrigin-RevId: 169195496
A. Unique TensorFlower 2017-09-18 22:12:36 -07:00 committed by TensorFlower Gardener
parent dba4bf1e7b
commit f08ec5722b
6 changed files with 40 additions and 10 deletions

View File

@@ -237,7 +237,11 @@ def imperative_grad(
     for i in range(len(out_gradients)):
       if out_gradients[i] is None:
         # TODO(apassos) this should be in the right device
-        out_gradients[i] = array_ops.zeros(*op_trace.output_shape_and_dtype[i])
+        none_indices = _grad_fn_accepts_none_for_indices.get(
+            op_trace.op_type, None)
+        if none_indices is None or i not in none_indices:
+          out_gradients[i] = array_ops.zeros(
+              *op_trace.output_shape_and_dtype[i])
       else:
         out_gradients[i] = _aggregate_grads(out_gradients[i])
@@ -335,8 +339,11 @@ def _magic_gradient_function(op_name, attr_tuple, num_inputs,
   grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
   if grad_fn is None:
     return [None] * num_inputs
+
+  none_indices = _grad_fn_accepts_none_for_indices.get(op_name, [])
   out_grads = [
-      o if (o is not None) else array_ops.zeros_like(outputs[i])
+      o if (o is not None or i in none_indices)
+      else array_ops.zeros_like(outputs[i])
       for i, o in enumerate(out_grads)
   ]
   return grad_fn(mock_op, *out_grads)
@@ -452,6 +459,24 @@ _ops_which_dont_need_inputs = set([
 ])
+
+# TODO(agarwal): use an automatic mechanism for handling None arguments to
+#   gradient functions.
+# Some gradient functions can accept None arguments for gradients. The
+# following maps the operation name to the indices at which the corresponding
+# gradient function can accept None values. e.g. FusedBatchNorm outputs 5
+# values and hence receives 5 gradient values during backprop. However, the
+# gradient function uses only the first of those values and ignores the rest.
+# The entry "FusedBatchNorm": [1, 2, 3, 4] indicates that only the gradient
+# corresponding to index 0 is used, and the gradient values at indices 1-4 are
+# ignored (and hence can be None). The backprop algorithm can then leverage
+# this by not constructing zeros to pass for those indices.
+_grad_fn_accepts_none_for_indices = {
+    "SoftmaxCrossEntropyWithLogits": [1],
+    "FusedBatchNorm": [1, 2, 3, 4]
+}
+
+
 def _record_gradient(op_name, inputs, attrs, results, name):
   """Records gradients for a TensorFlow operation.
@@ -500,7 +525,7 @@ def _record_gradient(op_name, inputs, attrs, results, name):
     return result

   inputs = [ops.convert_to_tensor(x) for x in inputs]
-  tape.record_operation(results, inputs, [], grad_fn)
+  tape.record_operation(op_name, results, inputs, [], grad_fn)
   if _tracing:
     print("Computed op", (name if name else op_name), "inputs", inputs,
           "outputs", results)

View File

@@ -55,6 +55,7 @@ def custom_gradient(f):
     flat_result = nest.flatten(result)
     tape.record_operation(
+        f.__name__,
         flat_result,
         input_tensors,
         [],
View File

@@ -86,7 +86,8 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
     tensor_map[ops.tensor_id(value)] = (value, captured_value)
   else:
     captured_value = captured_value[1]
-  tape.record_operation([captured_value], [value], [], lambda x: x)
+  tape.record_operation("captured_value", [captured_value], [value], [],
+                        lambda x: x)
   return captured_value
@@ -262,6 +263,7 @@ class _GraphModeFunction(object):
       side_outputs = outputs[len(self._returns):]
       tape.record_operation(
+          signature.name,
           real_outputs,
           (args + self._extra_inputs),
           side_outputs,

View File

@@ -31,6 +31,7 @@ def tid(tensor):
 class TapeEntry(
     collections.namedtuple("TapeEntry", [
+        "op_type",
         "output_ids", "input_ids", "side_outputs", "backward_function",
         "output_shape_and_dtype",
     ])):
@@ -96,8 +97,8 @@ class Tape(object):
     self._watched_variables.add(v)
     self.watch(v.handle)

-  def record_operation(self, output_tensors, input_tensors, side_outputs,
-                       backward_function):
+  def record_operation(self, op_type, output_tensors, input_tensors,
+                       side_outputs, backward_function):
     """Records an operation in the tape."""
     if not self.should_record(input_tensors):
       return output_tensors
@@ -109,6 +110,7 @@ class Tape(object):
       i = tid(t)
       self._tensor_usage[i] = self._tensor_usage.get(i, 0) + 1
     self._op_tape[self._next_op_id] = TapeEntry(
+        op_type,
         [tid(t) for t in output_tensors],
         [tid(t) for t in input_tensors],
         side_outputs,
@@ -225,11 +227,11 @@ def should_record(tensors):
   return any(x.should_record(tensors) for x in _tape_stack.stack)


-def record_operation(output_tensors, input_tensors, side_outputs,
+def record_operation(op_type, output_tensors, input_tensors, side_outputs,
                      backward_function):
   """Records the operation on all tapes in the stack."""
   for t in _tape_stack.stack:
-    t.record_operation(output_tensors,
+    t.record_operation(op_type, output_tensors,
                        input_tensors,
                        side_outputs,
                        backward_function)
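
The tape changes above exist so that the backward pass can key into _grad_fn_accepts_none_for_indices: every recorded operation now carries its op type. A minimal sketch of the shape of that record, using simplified stand-in names (SimpleTapeEntry, SimpleTape) rather than the real classes:

import collections

# Simplified stand-in for TapeEntry after this change: op_type is stored first
# so the backward pass can look up per-op metadata by name.
SimpleTapeEntry = collections.namedtuple(
    "SimpleTapeEntry",
    ["op_type", "output_ids", "input_ids", "side_outputs", "backward_function"])


class SimpleTape(object):
  """Minimal tape recording which op type produced which tensors."""

  def __init__(self):
    self._entries = []

  def record_operation(self, op_type, output_tensors, input_tensors,
                       side_outputs, backward_function):
    # Mirrors the new signature: callers pass an op type string ("MatMul",
    # "captured_value", "_copy", ...) as the first argument.
    self._entries.append(SimpleTapeEntry(
        op_type,
        [id(t) for t in output_tensors],
        [id(t) for t in input_tensors],
        side_outputs,
        backward_function))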

View File

@@ -769,7 +769,7 @@ class EagerTensor(Tensor):
        grad_h = c_api.TFE_TensorHandleCopyToDevice(
            dresult._handle, ctx._handle, self_device, status)
        return _tensor_from_handle(grad_h)
-    tape.record_operation([new_tensor], [self], [], grad_fun)
+    tape.record_operation("_copy", [new_tensor], [self], [], grad_fun)
     return new_tensor
     # pylint: enable=protected-access

View File

@@ -429,7 +429,7 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
     const_fill_value = tensor_util.constant_value(g)
     return const_fill_value is not None and (const_fill_value == 0).all()

-  if not IsZero(grad_grad):
+  if grad_grad is not None and not IsZero(grad_grad):
     logits = op.inputs[0]
     softmax = nn_ops.softmax(logits)
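
An op listed in _grad_fn_accepts_none_for_indices must have a gradient function that tolerates None in those slots, which is what the guard above adds for index 1 of SoftmaxCrossEntropyWithLogits. A toy sketch of the pattern (NumPy only; toy_cross_entropy_grad and its math are illustrative, not the real registered gradient):

import numpy as np


def _softmax(logits):
  """Numerically stable softmax along the last axis."""
  shifted = logits - logits.max(axis=-1, keepdims=True)
  exp = np.exp(shifted)
  return exp / exp.sum(axis=-1, keepdims=True)


def toy_cross_entropy_grad(logits, labels, grad_loss, grad_grad):
  """Gradient sketch mirroring the guard above: grad_grad may be None.

  grad_grad corresponds to index 1 in _grad_fn_accepts_none_for_indices, so
  backprop may now hand this function None instead of an all-zeros tensor.
  """
  grad = grad_loss[..., None] * (_softmax(logits) - labels)
  if grad_grad is not None and np.any(grad_grad):
    # A second-order contribution would be added here; it is skipped entirely
    # when grad_grad is None or all zeros, much like IsZero() above.
    pass
  return grad, None


if __name__ == "__main__":
  logits = np.array([[2.0, 1.0, 0.1]])
  labels = np.array([[1.0, 0.0, 0.0]])
  g, _ = toy_cross_entropy_grad(logits, labels, np.array([1.0]), None)
  print(g)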