Use the name scope of the forward pass in the backward pass.

PiperOrigin-RevId: 294732289
Change-Id: I15b20a033f66f4d4201b0eb6ea9be8b6cd82bf56
Author: Jiho Choi (2020-02-12 12:48:34 -08:00)
Committer: TensorFlower Gardener
Parent: 7e2063e88f
Commit: 77e9ffb9b2
7 changed files with 65 additions and 16 deletions
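
In effect, backward-pass ops recorded by the gradient tape are now created under a "gradient_tape/" prefix joined with the forward op's name scope, so forward and backward ops line up in graph visualizations and op-level instrumentation. A minimal sketch of the resulting behavior, using the public tf.function / tf.GradientTape API rather than the internal helpers exercised in the test below (function and variable names here are illustrative, not from this commit):

import tensorflow as tf

@tf.function
def take_gradient(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    with tf.name_scope('my_scope'):
      y = tf.cos(x)
  return tape.gradient(y, x)

graph = take_gradient.get_concrete_function(tf.constant(1.0)).graph
for op in graph.get_operations():
  if op.type == 'Sin':  # the backward op generated for Cos
    print(op.name)      # expected to contain 'gradient_tape/my_scope/'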

View File

@@ -514,6 +514,7 @@ py_library(
         ":tape",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",

View File

@@ -38,6 +38,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import default_gradient
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_math_ops
@@ -123,7 +124,7 @@ class _MockOp(object):

 def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
-                       out_grads, skip_input_indices):
+                       out_grads, skip_input_indices, forward_pass_name_scope):
   """Calls the gradient function of the op.

   Args:
@@ -135,6 +136,7 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
     out_grads: gradients of the operation wrt its outputs.
     skip_input_indices: a tuple that is passed to the gradient function,
       indicating which inputs to skip calculating the gradient for.
+    forward_pass_name_scope: the name scope of the op in the forward pass.

   Returns:
     The gradients with respect to the inputs of the function, as a list.
@@ -144,7 +146,17 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
   if grad_fn is None:
     return [None] * num_inputs

-  return grad_fn(mock_op, *out_grads)
+  # This does not work with v1 TensorArrays.
+  if ops.executing_eagerly_outside_functions(
+  ) or control_flow_util.EnableControlFlowV2(ops.get_default_graph()):
+    if forward_pass_name_scope:
+      gradient_name_scope = "gradient_tape/" + forward_pass_name_scope + "/"
+    else:
+      gradient_name_scope = "gradient_tape/"
+    with ops.name_scope(gradient_name_scope):
+      return grad_fn(mock_op, *out_grads)
+  else:
+    return grad_fn(mock_op, *out_grads)


 pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
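
The new branch applies the scoped naming only when executing eagerly outside functions or when control flow v2 is enabled; as the inline comment notes, it does not work with v1 TensorArrays, so legacy graphs keep the old flat naming. The naming rule itself is small enough to restate as a hypothetical standalone helper (not part of the commit) mirroring the inline logic:

def gradient_name_scope(forward_pass_name_scope):
  # Mirrors the branch above: scoped ops nest under their forward scope,
  # unscoped ops fall back to the bare prefix.
  if forward_pass_name_scope:
    return "gradient_tape/" + forward_pass_name_scope + "/"
  return "gradient_tape/"

assert gradient_name_scope("my_scope") == "gradient_tape/my_scope/"
assert gradient_name_scope("") == "gradient_tape/"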
@@ -155,7 +167,8 @@ def _must_record_gradient():


 def _record_gradient(op_name, inputs, attrs, results):
-  return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results)
+  return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results,
+                                          ops.get_name_scope())


 execute.must_record_gradient = _must_record_gradient
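
The scope is captured at record time, not at gradient time: ops.get_name_scope() reads the scope that is active when the forward op executes, and that string travels with the tape entry into the backward pass. A small sketch of what the call returns, assuming TF 2.x eager execution:

import tensorflow as tf
from tensorflow.python.framework import ops

with tf.name_scope("outer"):
  with tf.name_scope("inner"):
    # This is the string that would be recorded alongside an op
    # executed here: nested scopes joined with '/'.
    print(ops.get_name_scope())  # -> "outer/inner"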

View File

@@ -1512,6 +1512,27 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
       tape.gradient(z, z)
     self.assertEqual((z,), tape.watched_variables())

+  def testNameScope(self):
+
+    def fn(x):
+      with ops.name_scope('my_scope'):
+        a = math_ops.cos(x)
+        b = math_ops.cos(x)
+        return math_ops.add(a, b)
+
+    @function.defun
+    def grad_fn(x):
+      return backprop.gradients_function(fn)(x)
+
+    grad_ops = grad_fn.get_concrete_function(
+        constant_op.constant(1.0)).graph.get_operations()
+    num_sin_ops_found = 0
+    for op in grad_ops:
+      if op.type == 'Sin':
+        num_sin_ops_found += 1
+        self.assertIn('gradient_tape/my_scope/', op.name)
+    self.assertEqual(num_sin_ops_found, 2)
+

 class JacobianTest(test.TestCase):

   def _jacobian(self, experimental_use_pfor):
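
The test asserts on Sin ops because the gradient of Cos is built from Sin (d/dx cos x = -sin x): each of the two cos calls in fn contributes exactly one Sin op to the backward graph, and both must be named under gradient_tape/my_scope/.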

View File

@@ -275,7 +275,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args);

 // Record the gradient for a given op.
 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
-                                PyObject* attrs, PyObject* results);
+                                PyObject* attrs, PyObject* results,
+                                PyObject* forward_pass_name_scope);

 // Returns all variables watched by the given tape in the order those variables
 // were created.

View File

@@ -2916,7 +2916,8 @@ PyObject* CopySequenceSettingIndicesToNull(
 }

 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
-                         PyObject* results) {
+                         PyObject* results,
+                         PyObject* forward_pass_name_scope = nullptr) {
   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
   if (PyErr_Occurred()) return nullptr;
   std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
@@ -2997,16 +2998,21 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));

+  if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
+
   TapeSetRecordOperation(
       op_name, inputs, results, input_ids, input_dtypes,
-      [op_name, attrs, num_inputs, op_inputs, op_outputs]() {
+      [op_name, attrs, num_inputs, op_inputs, op_outputs,
+       forward_pass_name_scope]() {
         Py_INCREF(op_name);
         Py_INCREF(attrs);
         Py_INCREF(num_inputs);
         Py_INCREF(op_inputs);
         Py_INCREF(op_outputs);
+        Py_INCREF(forward_pass_name_scope);
         PyBackwardFunction* function = new PyBackwardFunction(
-            [op_name, attrs, num_inputs, op_inputs, op_outputs](
+            [op_name, attrs, num_inputs, op_inputs, op_outputs,
+             forward_pass_name_scope](
                 PyObject* output_grads,
                 const std::vector<tensorflow::int64>& unneeded_gradients) {
               if (PyErr_Occurred()) {
@@ -3026,8 +3032,9 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
                 skip_input_indices.reset(Py_None);
               }
               tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
-                  "OOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
-                  output_grads, skip_input_indices.get()));
+                  "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
+                  output_grads, skip_input_indices.get(),
+                  forward_pass_name_scope));

               tensorflow::Safe_PyObjectPtr result(
                   PyObject_CallObject(gradient_function, callback_args.get()));
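
The format string gains an eighth "O", one per PyObject* argument (each "O" inserts the object as-is while incrementing its reference count), so the argument tuple now matches the eight-parameter signature that _gradient_function acquired in backprop.py above.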
@@ -3038,13 +3045,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
             });
         return function;
       },
-      [op_name, attrs, num_inputs, op_inputs,
-       op_outputs](PyBackwardFunction* backward_function) {
+      [op_name, attrs, num_inputs, op_inputs, op_outputs,
+       forward_pass_name_scope](PyBackwardFunction* backward_function) {
         Py_DECREF(op_name);
         Py_DECREF(attrs);
         Py_DECREF(num_inputs);
         Py_DECREF(op_inputs);
         Py_DECREF(op_outputs);
+        Py_DECREF(forward_pass_name_scope);
         delete backward_function;
       },
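
Reference counting stays symmetric: forward_pass_name_scope is INCREF'd alongside the other captured objects when the backward function is built and DECREF'd here in the deleter, so the scope string (or Py_None) stays alive exactly as long as the tape holds the recorded backward function.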
@@ -3668,12 +3676,14 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
 }

 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
-                                PyObject* attrs, PyObject* results) {
+                                PyObject* attrs, PyObject* results,
+                                PyObject* forward_pass_name_scope) {
   if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
     Py_RETURN_NONE;
   }

-  return RecordGradient(op_name, inputs, attrs, results);
+  return RecordGradient(op_name, inputs, attrs, results,
+                        forward_pass_name_scope);
 }

 namespace {

View File

@@ -794,7 +794,8 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
     self.assertIn(_COS_OP, instrument.graph_op_types)

     # Check the ndarrays from runtime.
-    cos_op_outputs = instrument.graph_internal_ndarrays[_COS_OP]
+    cos_op_outputs = instrument.graph_internal_ndarrays[b"gradient_tape/" +
+                                                        _COS_OP]
     self.assertEqual(len(cos_op_outputs), 1)
     self.assertAllClose(cos_op_outputs[0], np.cos(3.0 * 3.0))
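
This test update shows the user-visible side of the change even without an explicit forward scope: backward ops that were previously keyed under their bare name are now recorded under a gradient_tape/ prefix, so the instrumentation lookup key moves with them (the b"..." literal because the op-name keys are bytes in this test).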

View File

@@ -582,9 +582,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
         });
   m.def("TFE_Py_RecordGradient",
         [](const py::handle& op_name, const py::handle& inputs,
-           const py::handle& attrs, const py::handle& results) {
+           const py::handle& attrs, const py::handle& results,
+           const py::handle& forward_pass_name_scope) {
           return tensorflow::pyo_or_throw(TFE_Py_RecordGradient(
-              op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr()));
+              op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
+              forward_pass_name_scope.ptr()));
         });
   m.def("TFE_Py_UID", []() { return tensorflow::pyo_or_throw(TFE_Py_UID()); });