From b4224fe34ec994846ce3166d9735c8036e9dfc4e Mon Sep 17 00:00:00 2001
From: Alexandre Passos
Date: Fri, 15 Sep 2017 11:40:08 -0700
Subject: [PATCH] Certain ops don't need eager gradients to keep their inputs / outputs alive.

PiperOrigin-RevId: 168864350
---
 tensorflow/python/eager/backprop.py     | 128 +++++++++++++++++++++++-
 tensorflow/python/eager/core.py         |   4 +-
 tensorflow/python/eager/execute.py      |   5 +-
 tensorflow/python/eager/memory_trace.py |  11 +-
 tensorflow/python/eager/tape.py         |   9 +-
 5 files changed, 142 insertions(+), 15 deletions(-)

diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index cc9e269e116..92af5f3edf6 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -349,6 +349,109 @@ _gradient_functions_lock = threading.Lock()
 _tracing = False
 
 
+# TODO(apassos) replace this with a mechanism which can happen at the op
+# gradient function registration site, to be less error-prone
+# TODO(apassos) add ops other than those in nn_grad and math_grad
+_ops_which_dont_need_outputs = set([
+    "MatMul",
+    "Conv2DBackpropInput",
+    "Conv2DBackpropFilter",
+    "Conv3D",
+    "Conv3DBackpropInputV2",
+    "AvgPool3D",
+    "AvgPool3DGrad",
+    "MaxPool3D",
+    "MaxPool3DGrad",
+    "MaxPool3DGradGrad",
+    "BiasAdd",
+    "BiasAddV1",
+    "BiasAddGrad",
+    "Relu6",
+    "Softplus",
+    "SoftplusGrad",
+    "Softsign",
+    "ReluGrad",
+    "Conv2D",
+    "DepthwiseConv2dNative",
+    "Dilation2D",
+    "AvgPool",
+    "AvgPoolGrad",
+    "BatchNormWithGlobalNormalization",
+    "L2Loss",
+    "Sum",
+    "Prod",
+    "SegmentSum",
+    "SegmentMean",
+    "SparseSegmentSum",
+    "SparseSegmentMean",
+    "SparseSegmentSqrtN",
+    "SegmentMin",
+    "SegmentMax",
+    "UnsortedSegmentSum",
+    "UnsortedSegmentMax",
+    "Abs",
+    "Neg",
+    "ReciprocalGrad",
+    "Square",
+    "Expm1",
+    "Log",
+    "Log1p",
+    "TanhGrad",
+    "SigmoidGrad",
+    "Sign",
+    "Sin",
+    "Cos",
+    "Tan",
+    "Add",
+    "Sub",
+    "Mul",
+    "Div",
+    "RealDiv",
+    "Pow",
+    "Maximum",
+    "Minimum",
+    "SquaredDifference",
+    "Select",
+    "SparseMatMul",
+    "BatchMatMul",
+    "Complex",
+    "Real",
+    "Imag",
+    "Angle",
+    "Conj",
+    "Cast",
+    "Cross",
+    "Cumsum",
+    "Cumprod",
+    "ReadVariableOp",
+    "VarHandleOp",
+    "Shape",
+])
+
+_ops_which_dont_need_inputs = set([
+    "Softmax",
+    "LogSoftmax",
+    "BiasAdd",
+    "Relu",
+    "Elu",
+    "Selu",
+    "SparseSoftmaxCrossEntropyWithLogits",
+    "Neg",
+    "Inv",
+    "Reciprocal",
+    "Sqrt",
+    "Exp",
+    "Tanh",
+    "Sigmoid",
+    "Real",
+    "Imag",
+    "Conj",
+    "ReadVariableOp",
+    "VarHandleOp",
+    "Shape",
+])
+
+
 def _record_gradient(op_name, inputs, attrs, results, name):
   """Records gradients for a TensorFlow operation.
 
@@ -367,13 +470,32 @@ def _record_gradient(op_name, inputs, attrs, results, name):
   Raises:
     An exception on error.
   """
+  if not tape.could_possibly_record():
+    return
+
+  if op_name in _ops_which_dont_need_outputs:
+    op_outputs = None
+  else:
+    # TODO(apassos) this line creates a weak circular reference where the
+    # backprop function keeps an output alive which in turn keeps the tape entry
+    # alive which keeps the backprop function alive. Figure out how to break
+    # this up without breaking second derivatives of ops like Exp whose
+    # gradients depend only on the outputs.
+    op_outputs = results
+
+  if op_name in _ops_which_dont_need_inputs:
+    op_inputs = None
+  else:
+    op_inputs = inputs
+
+  num_inputs = len(inputs)
 
   def grad_fn(*orig_outputs):
     """Generated gradient function."""
-    result = _magic_gradient_function(op_name, attrs, len(inputs),
-                                      inputs, results, orig_outputs)
+    result = _magic_gradient_function(op_name, attrs, num_inputs,
+                                      op_inputs, op_outputs, orig_outputs)
     if _tracing:
-      print("Gradient for", (name if name else op_name), "inputs", inputs,
+      print("Gradient for", (name if name else op_name), "inputs", op_inputs,
             "output_grads", orig_outputs, "gradients", result)
     return result
 
diff --git a/tensorflow/python/eager/core.py b/tensorflow/python/eager/core.py
index 64c615fb63b..b6e7d53ee8b 100644
--- a/tensorflow/python/eager/core.py
+++ b/tensorflow/python/eager/core.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.eager import context
 from tensorflow.python.eager import memory_trace
 from tensorflow.python.framework import errors
 
@@ -57,8 +56,7 @@ def enable_tracing():
   WARNING: tracing is not thread-safe.
   """
   global _active_trace
-  _active_trace = memory_trace.MemoryTrace(
-      len(context.get_default_context().devices()))
+  _active_trace = memory_trace.MemoryTrace()
 
 
 def flush_trace():
diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py
index 07d37cc5005..50a23f585bd 100644
--- a/tensorflow/python/eager/execute.py
+++ b/tensorflow/python/eager/execute.py
@@ -73,12 +73,11 @@ def execute(op_name, num_outputs, inputs, attrs=None, name=None):
   tensors = [ops._tensor_from_handle(x) for x in outh]  # pylint: disable=protected-access
   # TODO(alive, cais): Use the execution callback mechanism.
   if core.active_trace() is not None:
-    trace_name = name if name else op_name
     for t in tensors:
       # pylint: disable=protected-access
-      core.active_trace().record_tensor(trace_name,
+      core.active_trace().record_tensor(op_name,
                                         ops.tensor_id(t),
-                                        t._device_name(),
+                                        t.device,
                                         t.shape.num_elements())
       # pylint: enable=protected-access
 
diff --git a/tensorflow/python/eager/memory_trace.py b/tensorflow/python/eager/memory_trace.py
index 0baf9224080..094bcab9e2e 100644
--- a/tensorflow/python/eager/memory_trace.py
+++ b/tensorflow/python/eager/memory_trace.py
@@ -29,29 +29,30 @@ TensorData = collections.namedtuple(
 
 class MemoryTrace(object):
   """Records a trace of memory usage over operation execution."""
 
-  def __init__(self, n_devices):
+  def __init__(self):
     self.trace = []
     self.tensor_to_data = {}
-    self.current_device_mem_usage = [0] * n_devices
+    self.current_device_mem_usage = collections.defaultdict(int)
 
   def record_tensor(self, op_name, tensor_id, device, size):
     self.current_device_mem_usage[device] += size
     self.tensor_to_data[tensor_id] = TensorData(op_name, size, device)
     self.trace.append(TraceEntry(op_name, tensor_id,
-                                 self.current_device_mem_usage[:],
+                                 dict(self.current_device_mem_usage.items()),
                                  device, size))
 
   def delete_tensor(self, tensor_id):
     if tensor_id not in self.tensor_to_data:
       return
-    data = self.tensor_to_data.pop(tensor_id)
+    data = self.tensor_to_data.pop(tensor_id, None)
+    if data is None: return
     self.current_device_mem_usage[data.device] -= data.tensor_size
     self.trace.append(TraceEntry(data.op_name, tensor_id,
-                                 self.current_device_mem_usage[:],
+                                 dict(self.current_device_mem_usage.items()),
                                  data.device, -data.tensor_size))
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 899325cb20a..7ba7d0e7ec9 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -121,7 +121,9 @@ class Tape(object):
       self._tensor_usage[i] -= 1
       if self._tensor_usage[i] == 0:
         del self._tensor_usage[i]
-        op_id = self._tensor_tape.pop(i)
+        op_id = self._tensor_tape.pop(i, None)
+        if op_id is None:
+          return
         op = self._op_tape[op_id]
         if not any(tensor_id in self._tensor_usage
                    for tensor_id in op.output_ids):
@@ -247,3 +249,8 @@ def top_tape_watched_tensors():
 def top_tape_watched_variables():
   t = _tape_stack.stack[-1]
   return t._watched_variables  # pylint: disable=protected-access
+
+
+def could_possibly_record():
+  """Returns True if any tape is active."""
+  return len(_tape_stack.stack) > 0  # pylint: disable=g-explicit-length-test
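
As an illustration of the idea behind this change (a sketch only, not part of the patch and not TensorFlow's real API): when an op's gradient function never reads the forward inputs, or never reads the forward outputs, the tape can avoid retaining references to those tensors, so they can be freed as soon as user code drops them. All names below (ToyTape, OPS_NOT_NEEDING_INPUTS, OPS_NOT_NEEDING_OUTPUTS, exp_grad) are hypothetical and exist only to show the bookkeeping.

# Minimal standalone sketch of selective capture for eager gradients.
# Assumed/hypothetical names; plain Python, runnable as-is.
import math

OPS_NOT_NEEDING_OUTPUTS = {"MatMul", "Add"}  # gradients read inputs only
OPS_NOT_NEEDING_INPUTS = {"Exp", "Sigmoid"}  # gradients read outputs only


class ToyTape(object):

  def __init__(self):
    self.entries = []

  def record(self, op_name, inputs, outputs, grad_fn):
    # Keep references only to what the gradient function will actually read,
    # so the other tensors can be freed as soon as the caller drops them.
    kept_inputs = None if op_name in OPS_NOT_NEEDING_INPUTS else inputs
    kept_outputs = None if op_name in OPS_NOT_NEEDING_OUTPUTS else outputs
    self.entries.append((op_name, kept_inputs, kept_outputs, grad_fn))


def exp_grad(inputs, outputs, grad):
  # d/dx exp(x) == exp(x): only the forward *output* is needed.
  del inputs  # unused; the tape recorded None for it anyway
  return [grad * outputs[0]]


tape = ToyTape()
x = 2.0
y = math.exp(x)
tape.record("Exp", [x], [y], exp_grad)
_, kept_in, kept_out, grad_fn = tape.entries[0]
print(kept_in)                          # None: the input was not retained
print(grad_fn(kept_in, kept_out, 1.0))  # [7.38905609893065] == [exp(2.0)]

In the patch itself the same decision is driven by the _ops_which_dont_need_outputs and _ops_which_dont_need_inputs sets in backprop.py, and _record_gradient additionally returns early when tape.could_possibly_record() is false, so no gradient closure is built at all when no tape is active.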