Certain ops don't need eager gradients to keep their inputs / outputs alive.

PiperOrigin-RevId: 168864350
Alexandre Passos 2017-09-15 11:40:08 -07:00 committed by TensorFlower Gardener
parent d8a02b0919
commit b4224fe34e
5 changed files with 142 additions and 15 deletions
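For context on the commit message: whether an op's inputs or outputs must be kept alive for backprop follows directly from its gradient formula, and the two sets added to backprop below encode this. For example, MatMul's gradient reads only its inputs, while Exp's reads only its output. A minimal NumPy sketch of the two cases (textbook identities for illustration, not TensorFlow's registered gradient functions):

# Why MatMul is in _ops_which_dont_need_outputs and Exp is in
# _ops_which_dont_need_inputs (illustration only).
import numpy as np

def matmul_grad(a, b, upstream):
  # y = a @ b: dL/da = upstream @ b.T, dL/db = a.T @ upstream.
  # Both read only the forward *inputs*, so y need not stay alive.
  return upstream @ b.T, a.T @ upstream

def exp_grad(y, upstream):
  # y = exp(x): d exp(x)/dx = exp(x) = y.
  # The gradient reads only the forward *output*, so x need not stay alive.
  return upstream * y

a, b = np.random.randn(2, 3), np.random.randn(3, 4)
da, db = matmul_grad(a, b, np.ones((2, 4)))   # forward result never needed

y = np.exp(np.random.randn(5))
dx = exp_grad(y, np.ones_like(y))             # forward input never needed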

View File

@@ -349,6 +349,109 @@ _gradient_functions_lock = threading.Lock()
_tracing = False
# TODO(apassos) replace this with a mechanism which can happen at the op
# gradient function registration site, to be less error-prone
# TODO(apassos) add ops other than those in nn_grad and math_grad
_ops_which_dont_need_outputs = set([
"MatMul",
"Conv2DBackpropInput",
"Conv2DBackpropFilter",
"Conv3D",
"Conv3DBackpropInputV2",
"AvgPool3D",
"AvgPool3DGrad",
"MaxPool3D",
"MaxPool3DGrad",
"MaxPool3DGradGrad",
"BiasAdd",
"BiasAddV1",
"BiasAddGrad",
"Relu6",
"Softplus",
"SoftplusGrad",
"Softsign",
"ReluGrad",
"Conv2D",
"DepthwiseConv2dNative",
"Dilation2D",
"AvgPool",
"AvgPoolGrad",
"BatchNormWithGlobalNormalization",
"L2Loss",
"Sum",
"Prod",
"SegmentSum",
"SegmentMean",
"SparseSegmentSum",
"SparseSegmentMean",
"SparseSegmentSqrtN",
"SegmentMin",
"SegmentMax",
"UnsortedSegmentSum",
"UnsortedSegmentMax",
"Abs",
"Neg",
"ReciprocalGrad",
"Square",
"Expm1",
"Log",
"Log1p",
"TanhGrad",
"SigmoidGrad",
"Sign",
"Sin",
"Cos",
"Tan",
"Add",
"Sub",
"Mul",
"Div",
"RealDiv",
"Pow",
"Maximum",
"Minimum",
"SquaredDifference",
"Select",
"SparseMatMul",
"BatchMatMul",
"Complex",
"Real",
"Imag",
"Angle",
"Conj",
"Cast",
"Cross",
"Cumsum",
"Cumprod",
"ReadVariableOp",
"VarHandleOp",
"Shape",
])
_ops_which_dont_need_inputs = set([
"Softmax",
"LogSoftmax",
"BiasAdd",
"Relu",
"Elu",
"Selu",
"SparseSoftmaxCrossEntropyWithLogits",
"Neg",
"Inv",
"Reciprocal",
"Sqrt",
"Exp",
"Tanh",
"Sigmoid",
"Real",
"Imag",
"Conj",
"ReadVariableOp",
"VarHandleOp",
"Shape",
])
def _record_gradient(op_name, inputs, attrs, results, name):
"""Records gradients for a TensorFlow operation.
@@ -367,13 +470,32 @@ def _record_gradient(op_name, inputs, attrs, results, name):
Raises:
An exception on error.
"""
if not tape.could_possibly_record():
return
if op_name in _ops_which_dont_need_outputs:
op_outputs = None
else:
# TODO(apassos) this line creates a weak circular reference where the
# backprop function keeps an output alive which in turn keeps the tape entry
# alive which keeps the backprop function alive. Figure out how to break
# this up without breaking second derivatives of ops like Exp whose
# gradients depend only on the outputs.
op_outputs = results
if op_name in _ops_which_dont_need_inputs:
op_inputs = None
else:
op_inputs = inputs
num_inputs = len(inputs)
def grad_fn(*orig_outputs):
"""Generated gradient function."""
-result = _magic_gradient_function(op_name, attrs, len(inputs),
-inputs, results, orig_outputs)
+result = _magic_gradient_function(op_name, attrs, num_inputs,
+op_inputs, op_outputs, orig_outputs)
if _tracing:
print("Gradient for", (name if name else op_name), "inputs", inputs,
print("Gradient for", (name if name else op_name), "inputs", op_inputs,
"output_grads", orig_outputs, "gradients", result)
return result
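To make the memory effect of the _record_gradient change above concrete, here is a framework-free sketch of the same selective-capture pattern: the backward closure captures only what the gradient rule will actually read, so whatever was replaced with None can be freed as soon as the caller drops it. The set names and the make_backward helper are illustrative stand-ins, not TensorFlow API.

# Sketch of the selective-capture idea (CPython; illustration only).
import weakref

DONT_NEED_OUTPUTS = {"MatMul"}   # gradient reads inputs only
DONT_NEED_INPUTS = {"Exp"}       # gradient reads outputs only

def make_backward(op_name, inputs, outputs):
  op_inputs = None if op_name in DONT_NEED_INPUTS else inputs
  op_outputs = None if op_name in DONT_NEED_OUTPUTS else outputs
  def backward(upstream):
    # Only op_inputs / op_outputs are captured by this closure; whichever
    # was replaced with None is not kept alive through the tape entry.
    return (op_name, op_inputs, op_outputs, upstream)
  return backward

class Tensor(object):
  def __init__(self, value):
    self.value = value

x = Tensor(3.0)
y = Tensor(20.0855)                      # pretend y = exp(x)
x_ref = weakref.ref(x)
bwd = make_backward("Exp", [x], [y])     # Exp: the input reference is dropped
del x
print(x_ref() is None)                   # True: nothing pins x anymore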

View File

@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.eager import context
from tensorflow.python.eager import memory_trace
from tensorflow.python.framework import errors
@@ -57,8 +56,7 @@ def enable_tracing():
WARNING: tracing is not thread-safe.
"""
global _active_trace
-_active_trace = memory_trace.MemoryTrace(
-len(context.get_default_context().devices()))
+_active_trace = memory_trace.MemoryTrace()
def flush_trace():

View File

@@ -73,12 +73,11 @@ def execute(op_name, num_outputs, inputs, attrs=None, name=None):
tensors = [ops._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access
# TODO(alive, cais): Use the execution callback mechanism.
if core.active_trace() is not None:
-trace_name = name if name else op_name
for t in tensors:
# pylint: disable=protected-access
-core.active_trace().record_tensor(trace_name,
+core.active_trace().record_tensor(op_name,
ops.tensor_id(t),
-t._device_name(),
+t.device,
t.shape.num_elements())
# pylint: enable=protected-access

View File

@@ -29,29 +29,30 @@ TensorData = collections.namedtuple(
class MemoryTrace(object):
"""Records a trace of memory usage over operation execution."""
-def __init__(self, n_devices):
+def __init__(self):
self.trace = []
self.tensor_to_data = {}
-self.current_device_mem_usage = [0] * n_devices
+self.current_device_mem_usage = collections.defaultdict(int)
def record_tensor(self, op_name, tensor_id, device, size):
self.current_device_mem_usage[device] += size
self.tensor_to_data[tensor_id] = TensorData(op_name, size, device)
self.trace.append(TraceEntry(op_name,
tensor_id,
-self.current_device_mem_usage[:],
+dict(self.current_device_mem_usage.items()),
device,
size))
def delete_tensor(self, tensor_id):
-if tensor_id not in self.tensor_to_data:
-return
-data = self.tensor_to_data.pop(tensor_id)
+data = self.tensor_to_data.pop(tensor_id, None)
+if data is None: return
self.current_device_mem_usage[data.device] -= data.tensor_size
self.trace.append(TraceEntry(data.op_name,
tensor_id,
-self.current_device_mem_usage[:],
+dict(self.current_device_mem_usage.items()),
data.device,
-data.tensor_size))
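The memory_trace change above replaces the fixed-length per-device list with a collections.defaultdict(int) keyed by device string, and snapshots it with dict(...) for every trace entry. A small standalone sketch of why the snapshot copy matters: without it, every entry would alias the same mutable mapping and later updates would rewrite history. Device strings here are just examples.

import collections

usage = collections.defaultdict(int)    # device name -> bytes currently live
trace = []

def record(device, delta):
  # delta > 0 when a tensor is produced, < 0 when it is deleted.
  usage[device] += delta
  trace.append(dict(usage))             # snapshot, so old entries stay fixed

record("/device:CPU:0", 400)
record("/device:GPU:0", 1024)
record("/device:CPU:0", -400)           # the CPU tensor is freed

print(trace[0])   # {'/device:CPU:0': 400}
print(trace[2])   # {'/device:CPU:0': 0, '/device:GPU:0': 1024}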

View File

@@ -121,7 +121,9 @@ class Tape(object):
self._tensor_usage[i] -= 1
if self._tensor_usage[i] == 0:
del self._tensor_usage[i]
-op_id = self._tensor_tape.pop(i)
+op_id = self._tensor_tape.pop(i, None)
+if op_id is None:
+return
op = self._op_tape[op_id]
if not any(tensor_id in self._tensor_usage
for tensor_id in op.output_ids):
@@ -247,3 +249,8 @@ def top_tape_watched_tensors():
def top_tape_watched_variables():
t = _tape_stack.stack[-1]
return t._watched_variables # pylint: disable=protected-access
def could_possibly_record():
"""Returns True if any tape is active."""
return len(_tape_stack.stack) > 0 # pylint: disable=g-explicit-length-test
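could_possibly_record gives callers a cheap early exit so that, with no tape active, gradient bookkeeping does no work and holds no references. A toy sketch of that fast-path guard; the stack and record_operation below are illustrative stand-ins, not the real tape module.

_tape_stack = []                         # pushed/popped by tape scopes

def could_possibly_record():
  return len(_tape_stack) > 0

def record_operation(op_name, inputs, outputs, backward_fn):
  if not could_possibly_record():
    return                               # nothing is watching: skip entirely
  for tape in _tape_stack:
    tape.append((op_name, [id(t) for t in inputs],
                 [id(t) for t in outputs], backward_fn))

record_operation("MatMul", [], [], lambda g: g)   # no tape yet: a no-op
_tape_stack.append([])
record_operation("MatMul", [], [], lambda g: g)   # now it is recorded
print(len(_tape_stack[-1]))              # 1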