diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index ac88f18eba5..0b434621121 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -237,7 +237,11 @@ def imperative_grad(
     for i in range(len(out_gradients)):
       if out_gradients[i] is None:
         # TODO(apassos) this should be in the right device
-        out_gradients[i] = array_ops.zeros(*op_trace.output_shape_and_dtype[i])
+        none_indices = _grad_fn_accepts_none_for_indices.get(
+            op_trace.op_type, None)
+        if none_indices is None or i not in none_indices:
+          out_gradients[i] = array_ops.zeros(
+              *op_trace.output_shape_and_dtype[i])
       else:
         out_gradients[i] = _aggregate_grads(out_gradients[i])
@@ -335,8 +339,11 @@ def _magic_gradient_function(op_name, attr_tuple, num_inputs,
   grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
   if grad_fn is None:
     return [None] * num_inputs
+
+  none_indices = _grad_fn_accepts_none_for_indices.get(op_name, [])
   out_grads = [
-      o if (o is not None) else array_ops.zeros_like(outputs[i])
+      o if (o is not None or i in none_indices)
+      else array_ops.zeros_like(outputs[i])
       for i, o in enumerate(out_grads)
   ]
   return grad_fn(mock_op, *out_grads)
@@ -452,6 +459,24 @@ _ops_which_dont_need_inputs = set([
 ])
 
 
+# TODO(agarwal): use an automatic mechanism for handling None arguments to
+# gradient functions.
+# Some gradient functions can accept None arguments for gradients. The following
+# maps the operation name to the indices at which the corresponding gradient
+# function can accept None values.
+# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
+# during backprop. However the gradient function uses only the first of those
+# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
+# indicates that only the gradient corresponding to index 0 is used, and the
+# gradient values at indices 1-4 are ignored (and hence can be None). The
+# backprop algorithm can then leverage this by not constructing zeros to
+# pass for those indices.
+_grad_fn_accepts_none_for_indices = {
+    "SoftmaxCrossEntropyWithLogits": [1],
+    "FusedBatchNorm": [1, 2, 3, 4]
+}
+
+
 def _record_gradient(op_name, inputs, attrs, results, name):
   """Records gradients for a TensorFlow operation.
@@ -500,7 +525,7 @@ def _record_gradient(op_name, inputs, attrs, results, name):
     return result
 
   inputs = [ops.convert_to_tensor(x) for x in inputs]
-  tape.record_operation(results, inputs, [], grad_fn)
+  tape.record_operation(op_name, results, inputs, [], grad_fn)
   if _tracing:
     print("Computed op", (name if name else op_name), "inputs", inputs,
           "outputs", results)
diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py
index dbd5a509244..6d0634e140e 100644
--- a/tensorflow/python/eager/custom_gradient.py
+++ b/tensorflow/python/eager/custom_gradient.py
@@ -55,6 +55,7 @@ def custom_gradient(f):
     flat_result = nest.flatten(result)
     tape.record_operation(
+        f.__name__,
         flat_result,
         input_tensors,
         [],
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 3d51b53391a..b8ce8e74632 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -86,7 +86,8 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
     tensor_map[ops.tensor_id(value)] = (value, captured_value)
   else:
     captured_value = captured_value[1]
-  tape.record_operation([captured_value], [value], [], lambda x: x)
+  tape.record_operation("captured_value", [captured_value], [value], [],
+                        lambda x: x)
   return captured_value
@@ -262,6 +263,7 @@ class _GraphModeFunction(object):
     side_outputs = outputs[len(self._returns):]
     tape.record_operation(
+        signature.name,
         real_outputs,
         (args + self._extra_inputs),
         side_outputs,
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 7ba7d0e7ec9..e4fdaa111a0 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -31,6 +31,7 @@ def tid(tensor):
 
 class TapeEntry(
     collections.namedtuple("TapeEntry", [
+        "op_type",
         "output_ids", "input_ids", "side_outputs", "backward_function",
         "output_shape_and_dtype",
     ])):
@@ -96,8 +97,8 @@ class Tape(object):
       self._watched_variables.add(v)
       self.watch(v.handle)
 
-  def record_operation(self, output_tensors, input_tensors, side_outputs,
-                       backward_function):
+  def record_operation(self, op_type, output_tensors, input_tensors,
+                       side_outputs, backward_function):
     """Records an operation in the tape."""
     if not self.should_record(input_tensors):
       return output_tensors
@@ -109,6 +110,7 @@ class Tape(object):
       i = tid(t)
       self._tensor_usage[i] = self._tensor_usage.get(i, 0) + 1
     self._op_tape[self._next_op_id] = TapeEntry(
+        op_type,
         [tid(t) for t in output_tensors],
         [tid(t) for t in input_tensors],
         side_outputs,
@@ -225,11 +227,11 @@ def should_record(tensors):
   return any(x.should_record(tensors) for x in _tape_stack.stack)
 
 
-def record_operation(output_tensors, input_tensors, side_outputs,
+def record_operation(op_type, output_tensors, input_tensors, side_outputs,
                      backward_function):
   """Records the operation on all tapes in the stack."""
   for t in _tape_stack.stack:
-    t.record_operation(output_tensors,
+    t.record_operation(op_type, output_tensors,
                        input_tensors,
                        side_outputs,
                        backward_function)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 31ef1b0eb76..f24b3ba6055 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -769,7 +769,7 @@ class EagerTensor(Tensor):
           grad_h = c_api.TFE_TensorHandleCopyToDevice(
               dresult._handle, ctx._handle, self_device, status)
         return _tensor_from_handle(grad_h)
-      tape.record_operation([new_tensor], [self], [], grad_fun)
+      tape.record_operation("_copy", [new_tensor], [self], [], grad_fun)
     return new_tensor
     # pylint: enable=protected-access
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 6dd8fb31ca7..54627b6fd91 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -429,7 +429,7 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
     const_fill_value = tensor_util.constant_value(g)
     return const_fill_value is not None and (const_fill_value == 0).all()
 
-  if not IsZero(grad_grad):
+  if grad_grad is not None and not IsZero(grad_grad):
     logits = op.inputs[0]
     softmax = nn_ops.softmax(logits)
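
For context, the plain-Python sketch below (not part of the patch; `fill_missing_out_gradients` and `make_zeros` are hypothetical stand-ins for the loop in `imperative_grad` and for `array_ops.zeros`) mirrors the lookup that `imperative_grad` and `_magic_gradient_function` now perform: output gradients that the registered gradient function is known to ignore stay `None` instead of being materialized as zero tensors.

_grad_fn_accepts_none_for_indices = {
    "SoftmaxCrossEntropyWithLogits": [1],
    "FusedBatchNorm": [1, 2, 3, 4],
}


def fill_missing_out_gradients(op_type, out_gradients, make_zeros):
  """Replaces None entries with zeros unless the grad fn accepts None there."""
  none_ok = _grad_fn_accepts_none_for_indices.get(op_type, [])
  filled = []
  for i, grad in enumerate(out_gradients):
    if grad is None and i not in none_ok:
      grad = make_zeros(i)  # the allocation the patch skips when None is fine
    filled.append(grad)
  return filled


# FusedBatchNorm produces 5 outputs but its gradient function only consumes
# the gradient for output 0, so indices 1-4 stay None.
print(fill_missing_out_gradients(
    "FusedBatchNorm", [1.0, None, None, None, None], lambda i: 0.0))
# [1.0, None, None, None, None]
# An op with no entry in the table still gets zeros filled in.
print(fill_missing_out_gradients(
    "MatMul", [None, 2.0], lambda i: 0.0))
# [0.0, 2.0]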
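
The signature change is mechanical but touches every caller: `record_operation` now takes the op type as its first argument and stores it on the `TapeEntry`, which is what lets the backward pass key into the table above. A minimal sketch of that plumbing, with simplified, hypothetical names (`MiniTape`, `MiniTapeEntry`) and plain integers in place of tensors:

import collections

# Stripped-down analogue of TapeEntry / Tape.record_operation showing why
# op_type is threaded through every call site: the backward pass needs it
# to consult _grad_fn_accepts_none_for_indices.
MiniTapeEntry = collections.namedtuple(
    "MiniTapeEntry",
    ["op_type", "output_ids", "input_ids", "backward_function"])


class MiniTape(object):

  def __init__(self):
    self.entries = []

  def record_operation(self, op_type, output_ids, input_ids,
                       backward_function):
    # op_type is persisted alongside the op record, mirroring the new first
    # positional argument of tape.record_operation in the patch.
    self.entries.append(
        MiniTapeEntry(op_type, output_ids, input_ids, backward_function))


tape = MiniTape()
tape.record_operation("FusedBatchNorm", [0, 1, 2, 3, 4], [5], lambda *g: g)
print(tape.entries[0].op_type)  # FusedBatchNorm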