Certain ops' eager gradient functions don't need to keep their inputs / outputs alive.
PiperOrigin-RevId: 168864350
parent d8a02b0919
commit b4224fe34e
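The idea behind the two sets introduced in the first hunk: some gradients are functions of the op's inputs only, so the recorded outputs can be dropped, while others are functions of the op's outputs only, so the inputs can be dropped. A minimal NumPy sketch of both cases (illustrative only, not TensorFlow code; the function names are made up):

import numpy as np

def matmul_grad(inputs, outputs, grad):
  # d(A @ B) needs only the inputs A and B, so `outputs` may safely be None.
  del outputs  # MatMul is in _ops_which_dont_need_outputs
  a, b = inputs
  return [grad @ b.T, a.T @ grad]

def relu_grad(inputs, outputs, grad):
  # d(relu(x)) needs only the output y = relu(x), so `inputs` may safely be None.
  del inputs  # Relu is in _ops_which_dont_need_inputs
  (y,) = outputs
  return [grad * (y > 0)]

a, b = np.random.randn(2, 3), np.random.randn(3, 4)
y = a @ b
da, db = matmul_grad([a, b], None, np.ones_like(y))         # outputs dropped
dy, = relu_grad(None, [np.maximum(y, 0)], np.ones_like(y))  # inputs dropped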
@@ -349,6 +349,109 @@ _gradient_functions_lock = threading.Lock()
 _tracing = False


+# TODO(apassos) replace this with a mechanism which can happen at the op
+# gradient function registration site, to be less error-prone
+# TODO(apassos) add ops other than those in nn_grad and math_grad
+_ops_which_dont_need_outputs = set([
+    "MatMul",
+    "Conv2DBackpropInput",
+    "Conv2DBackpropFilter",
+    "Conv3D",
+    "Conv3DBackpropInputV2",
+    "AvgPool3D",
+    "AvgPool3DGrad",
+    "MaxPool3D",
+    "MaxPool3DGrad",
+    "MaxPool3DGradGrad",
+    "BiasAdd",
+    "BiasAddV1",
+    "BiasAddGrad",
+    "Relu6",
+    "Softplus",
+    "SoftplusGrad",
+    "Softsign",
+    "ReluGrad",
+    "Conv2D",
+    "DepthwiseConv2dNative",
+    "Dilation2D",
+    "AvgPool",
+    "AvgPoolGrad",
+    "BatchNormWithGlobalNormalization",
+    "L2Loss",
+    "Sum",
+    "Prod",
+    "SegmentSum",
+    "SegmentMean",
+    "SparseSegmentSum",
+    "SparseSegmentMean",
+    "SparseSegmentSqrtN",
+    "SegmentMin",
+    "SegmentMax",
+    "UnsortedSegmentSum",
+    "UnsortedSegmentMax",
+    "Abs",
+    "Neg",
+    "ReciprocalGrad",
+    "Square",
+    "Expm1",
+    "Log",
+    "Log1p",
+    "TanhGrad",
+    "SigmoidGrad",
+    "Sign",
+    "Sin",
+    "Cos",
+    "Tan",
+    "Add",
+    "Sub",
+    "Mul",
+    "Div",
+    "RealDiv",
+    "Pow",
+    "Maximum",
+    "Minimum",
+    "SquaredDifference",
+    "Select",
+    "SparseMatMul",
+    "BatchMatMul",
+    "Complex",
+    "Real",
+    "Imag",
+    "Angle",
+    "Conj",
+    "Cast",
+    "Cross",
+    "Cumsum",
+    "Cumprod",
+    "ReadVariableOp",
+    "VarHandleOp",
+    "Shape",
+])
+
+_ops_which_dont_need_inputs = set([
+    "Softmax",
+    "LogSoftmax",
+    "BiasAdd",
+    "Relu",
+    "Elu",
+    "Selu",
+    "SparseSoftmaxCrossEntropyWithLogits",
+    "Neg",
+    "Inv",
+    "Reciprocal",
+    "Sqrt",
+    "Exp",
+    "Tanh",
+    "Sigmoid",
+    "Real",
+    "Imag",
+    "Conj",
+    "ReadVariableOp",
+    "VarHandleOp",
+    "Shape",
+])
+
+
 def _record_gradient(op_name, inputs, attrs, results, name):
   """Records gradients for a TensorFlow operation.

@@ -367,13 +470,32 @@ def _record_gradient(op_name, inputs, attrs, results, name):
   Raises:
     An exception on error.
   """
   if not tape.could_possibly_record():
     return

+  if op_name in _ops_which_dont_need_outputs:
+    op_outputs = None
+  else:
+    # TODO(apassos) this line creates a weak circular reference where the
+    # backprop function keeps an output alive which in turn keeps the tape entry
+    # alive which keeps the backprop function alive. Figure out how to break
+    # this up without breaking second derivatives of ops like Exp whose
+    # gradients depend only on the outputs.
+    op_outputs = results
+
+  if op_name in _ops_which_dont_need_inputs:
+    op_inputs = None
+  else:
+    op_inputs = inputs
+
+  num_inputs = len(inputs)
+
   def grad_fn(*orig_outputs):
     """Generated gradient function."""
-    result = _magic_gradient_function(op_name, attrs, len(inputs),
-                                      inputs, results, orig_outputs)
+    result = _magic_gradient_function(op_name, attrs, num_inputs,
+                                      op_inputs, op_outputs, orig_outputs)
     if _tracing:
-      print("Gradient for", (name if name else op_name), "inputs", inputs,
+      print("Gradient for", (name if name else op_name), "inputs", op_inputs,
             "output_grads", orig_outputs, "gradients", result)
     return result

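A self-contained sketch of the memory effect this hunk is after (plain Python, no TensorFlow; FakeTensor and make_grad_fn are made-up names): a gradient closure that captures a tensor keeps it alive for as long as the tape entry exists, while a closure handed None lets the tensor be collected right after the forward pass.

import gc
import weakref

class FakeTensor(object):
  """Stand-in for an eager tensor that owns device memory."""

def make_grad_fn(op_outputs):
  # Mirrors grad_fn above: whatever is captured here stays reachable for as
  # long as the tape holds this closure.
  def grad_fn(grad):
    # A real gradient function would read op_outputs here when it is not None.
    return grad
  return grad_fn

t = FakeTensor()
ref = weakref.ref(t)
keeps = make_grad_fn(t)     # op needs its outputs: closure keeps t alive
drops = make_grad_fn(None)  # op in _ops_which_dont_need_outputs: t not captured
del t
gc.collect()
assert ref() is not None    # still reachable through `keeps`
del keeps
gc.collect()
assert ref() is None        # the tensor's memory could be released immediately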
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.eager import context
 from tensorflow.python.eager import memory_trace
 from tensorflow.python.framework import errors

@@ -57,8 +56,7 @@ def enable_tracing():
   WARNING: tracing is not thread-safe.
   """
   global _active_trace
-  _active_trace = memory_trace.MemoryTrace(
-      len(context.get_default_context().devices()))
+  _active_trace = memory_trace.MemoryTrace()


 def flush_trace():
@@ -73,12 +73,11 @@ def execute(op_name, num_outputs, inputs, attrs=None, name=None):
   tensors = [ops._tensor_from_handle(x) for x in outh]  # pylint: disable=protected-access
   # TODO(alive, cais): Use the execution callback mechanism.
   if core.active_trace() is not None:
-    trace_name = name if name else op_name
     for t in tensors:
       # pylint: disable=protected-access
-      core.active_trace().record_tensor(trace_name,
+      core.active_trace().record_tensor(op_name,
                                         ops.tensor_id(t),
-                                        t._device_name(),
+                                        t.device,
                                         t.shape.num_elements())
       # pylint: enable=protected-access

@@ -29,29 +29,30 @@ TensorData = collections.namedtuple(
 class MemoryTrace(object):
   """Records a trace of memory usage over operation execution."""

-  def __init__(self, n_devices):
+  def __init__(self):
     self.trace = []
     self.tensor_to_data = {}
-    self.current_device_mem_usage = [0] * n_devices
+    self.current_device_mem_usage = collections.defaultdict(int)

   def record_tensor(self, op_name, tensor_id, device, size):
     self.current_device_mem_usage[device] += size
     self.tensor_to_data[tensor_id] = TensorData(op_name, size, device)
     self.trace.append(TraceEntry(op_name,
                                  tensor_id,
-                                 self.current_device_mem_usage[:],
+                                 dict(self.current_device_mem_usage.items()),
                                  device,
                                  size))

   def delete_tensor(self, tensor_id):
-    if tensor_id not in self.tensor_to_data:
-      return
-    data = self.tensor_to_data.pop(tensor_id)
+    data = self.tensor_to_data.pop(tensor_id, None)
+    if data is None: return
     self.current_device_mem_usage[data.device] -= data.tensor_size
     self.trace.append(TraceEntry(data.op_name,
                                  tensor_id,
-                                 self.current_device_mem_usage[:],
+                                 dict(self.current_device_mem_usage.items()),
                                  data.device,
                                  -data.tensor_size))

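A standalone sketch of the accounting pattern the new MemoryTrace uses (the device strings below are made up): per-device byte counts live in a collections.defaultdict(int) keyed by device name, so no device count is needed up front, and each trace entry snapshots the current state with dict(...) so later updates don't mutate earlier entries.

import collections

device_mem_usage = collections.defaultdict(int)  # device name -> live bytes
trace = []                                       # one snapshot per event

def record(device, size):
  device_mem_usage[device] += size               # negative size == deallocation
  trace.append(dict(device_mem_usage.items()))   # copy; later ops won't mutate it

record("CPU:0", 1 << 10)
record("GPU:0", 1 << 20)
record("CPU:0", -(1 << 10))
print(trace[-1])  # {'CPU:0': 0, 'GPU:0': 1048576}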
@@ -121,7 +121,9 @@ class Tape(object):
      self._tensor_usage[i] -= 1
      if self._tensor_usage[i] == 0:
        del self._tensor_usage[i]
-       op_id = self._tensor_tape.pop(i)
+       op_id = self._tensor_tape.pop(i, None)
+       if op_id is None:
+         return
        op = self._op_tape[op_id]
        if not any(tensor_id in self._tensor_usage
                   for tensor_id in op.output_ids):
@@ -247,3 +249,8 @@ def top_tape_watched_tensors():
 def top_tape_watched_variables():
   t = _tape_stack.stack[-1]
   return t._watched_variables  # pylint: disable=protected-access
+
+
+def could_possibly_record():
+  """Returns True if any tape is active."""
+  return len(_tape_stack.stack) > 0  # pylint: disable=g-explicit-length-test
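could_possibly_record() enables the cheap early exit at the top of _record_gradient: when no tape is active, no backward function is built and nothing extra is kept alive. A hedged, standalone sketch of that pattern (the names below are illustrative, not the TensorFlow API):

_tape_stack = []  # stand-in for the per-thread stack of active tapes

def could_possibly_record():
  """Returns True if any tape is active."""
  return len(_tape_stack) > 0

def record_gradient(op_name, inputs, results):
  if not could_possibly_record():
    return  # nothing is watching: skip building and storing a backward function
  backward = lambda *grads: grads  # placeholder for the real gradient function
  _tape_stack[-1].append((op_name, inputs, results, backward))

record_gradient("MatMul", [], [])  # no active tape: nothing recorded
_tape_stack.append([])             # a tape becomes active
record_gradient("MatMul", [], [])  # now the op is taped
print(len(_tape_stack[-1]))        # 1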