Use the name scope of the forward pass in the backward pass.
PiperOrigin-RevId: 294732289 Change-Id: I15b20a033f66f4d4201b0eb6ea9be8b6cd82bf56
This commit is contained in:
parent
7e2063e88f
commit
77e9ffb9b2
@ -514,6 +514,7 @@ py_library(
|
||||
":tape",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_util",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
|
@ -38,6 +38,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import default_gradient
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
@ -123,7 +124,7 @@ class _MockOp(object):
|
||||
|
||||
|
||||
def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
|
||||
out_grads, skip_input_indices):
|
||||
out_grads, skip_input_indices, forward_pass_name_scope):
|
||||
"""Calls the gradient function of the op.
|
||||
|
||||
Args:
|
||||
@ -135,6 +136,7 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
|
||||
out_grads: gradients of the operation wrt its outputs.
|
||||
skip_input_indices: a tuple that is passed to the gradient function,
|
||||
indicating which inputs to skip calculating the gradient for
|
||||
forward_pass_name_scope: the namescope of the op in the forward pass.
|
||||
|
||||
Returns:
|
||||
The gradients with respect to the inputs of the function, as a list.
|
||||
@ -144,7 +146,17 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
|
||||
if grad_fn is None:
|
||||
return [None] * num_inputs
|
||||
|
||||
return grad_fn(mock_op, *out_grads)
|
||||
# This does not work with v1 TensorArrays.
|
||||
if ops.executing_eagerly_outside_functions(
|
||||
) or control_flow_util.EnableControlFlowV2(ops.get_default_graph()):
|
||||
if forward_pass_name_scope:
|
||||
gradient_name_scope = "gradient_tape/" + forward_pass_name_scope + "/"
|
||||
else:
|
||||
gradient_name_scope = "gradient_tape/"
|
||||
with ops.name_scope(gradient_name_scope):
|
||||
return grad_fn(mock_op, *out_grads)
|
||||
else:
|
||||
return grad_fn(mock_op, *out_grads)
|
||||
|
||||
|
||||
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
|
||||
@ -155,7 +167,8 @@ def _must_record_gradient():
|
||||
|
||||
|
||||
def _record_gradient(op_name, inputs, attrs, results):
|
||||
return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results)
|
||||
return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results,
|
||||
ops.get_name_scope())
|
||||
|
||||
|
||||
execute.must_record_gradient = _must_record_gradient
|
||||
|
@ -1512,6 +1512,27 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
tape.gradient(z, z)
|
||||
self.assertEqual((z,), tape.watched_variables())
|
||||
|
||||
def testNameScope(self):
|
||||
def fn(x):
|
||||
with ops.name_scope('my_scope'):
|
||||
a = math_ops.cos(x)
|
||||
b = math_ops.cos(x)
|
||||
return math_ops.add(a, b)
|
||||
|
||||
@function.defun
|
||||
def grad_fn(x):
|
||||
return backprop.gradients_function(fn)(x)
|
||||
|
||||
grad_ops = grad_fn.get_concrete_function(
|
||||
constant_op.constant(1.0)).graph.get_operations()
|
||||
num_sin_ops_found = 0
|
||||
for op in grad_ops:
|
||||
if op.type == 'Sin':
|
||||
num_sin_ops_found += 1
|
||||
self.assertIn('gradient_tape/my_scope/', op.name)
|
||||
self.assertEqual(num_sin_ops_found, 2)
|
||||
|
||||
|
||||
class JacobianTest(test.TestCase):
|
||||
|
||||
def _jacobian(self, experimental_use_pfor):
|
||||
|
@ -275,7 +275,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args);
|
||||
|
||||
// Record the gradient for a given op.
|
||||
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
|
||||
PyObject* attrs, PyObject* results);
|
||||
PyObject* attrs, PyObject* results,
|
||||
PyObject* forward_pass_name_scope);
|
||||
|
||||
// Returns all variables watched by the given tape in the order those variables
|
||||
// were created.
|
||||
|
@ -2916,7 +2916,8 @@ PyObject* CopySequenceSettingIndicesToNull(
|
||||
}
|
||||
|
||||
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
PyObject* results) {
|
||||
PyObject* results,
|
||||
PyObject* forward_pass_name_scope = nullptr) {
|
||||
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
|
||||
@ -2997,16 +2998,21 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
|
||||
PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
|
||||
|
||||
if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
|
||||
|
||||
TapeSetRecordOperation(
|
||||
op_name, inputs, results, input_ids, input_dtypes,
|
||||
[op_name, attrs, num_inputs, op_inputs, op_outputs]() {
|
||||
[op_name, attrs, num_inputs, op_inputs, op_outputs,
|
||||
forward_pass_name_scope]() {
|
||||
Py_INCREF(op_name);
|
||||
Py_INCREF(attrs);
|
||||
Py_INCREF(num_inputs);
|
||||
Py_INCREF(op_inputs);
|
||||
Py_INCREF(op_outputs);
|
||||
Py_INCREF(forward_pass_name_scope);
|
||||
PyBackwardFunction* function = new PyBackwardFunction(
|
||||
[op_name, attrs, num_inputs, op_inputs, op_outputs](
|
||||
[op_name, attrs, num_inputs, op_inputs, op_outputs,
|
||||
forward_pass_name_scope](
|
||||
PyObject* output_grads,
|
||||
const std::vector<tensorflow::int64>& unneeded_gradients) {
|
||||
if (PyErr_Occurred()) {
|
||||
@ -3026,8 +3032,9 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
skip_input_indices.reset(Py_None);
|
||||
}
|
||||
tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
|
||||
"OOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
|
||||
output_grads, skip_input_indices.get()));
|
||||
"OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
|
||||
output_grads, skip_input_indices.get(),
|
||||
forward_pass_name_scope));
|
||||
|
||||
tensorflow::Safe_PyObjectPtr result(
|
||||
PyObject_CallObject(gradient_function, callback_args.get()));
|
||||
@ -3038,13 +3045,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
});
|
||||
return function;
|
||||
},
|
||||
[op_name, attrs, num_inputs, op_inputs,
|
||||
op_outputs](PyBackwardFunction* backward_function) {
|
||||
[op_name, attrs, num_inputs, op_inputs, op_outputs,
|
||||
forward_pass_name_scope](PyBackwardFunction* backward_function) {
|
||||
Py_DECREF(op_name);
|
||||
Py_DECREF(attrs);
|
||||
Py_DECREF(num_inputs);
|
||||
Py_DECREF(op_inputs);
|
||||
Py_DECREF(op_outputs);
|
||||
Py_DECREF(forward_pass_name_scope);
|
||||
|
||||
delete backward_function;
|
||||
},
|
||||
@ -3668,12 +3676,14 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
|
||||
PyObject* attrs, PyObject* results) {
|
||||
PyObject* attrs, PyObject* results,
|
||||
PyObject* forward_pass_name_scope) {
|
||||
if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
return RecordGradient(op_name, inputs, attrs, results);
|
||||
return RecordGradient(op_name, inputs, attrs, results,
|
||||
forward_pass_name_scope);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -794,7 +794,8 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
|
||||
self.assertIn(_COS_OP, instrument.graph_op_types)
|
||||
|
||||
# Check the ndarrays from runtime.
|
||||
cos_op_outputs = instrument.graph_internal_ndarrays[_COS_OP]
|
||||
cos_op_outputs = instrument.graph_internal_ndarrays[b"gradient_tape/" +
|
||||
_COS_OP]
|
||||
self.assertEqual(len(cos_op_outputs), 1)
|
||||
self.assertAllClose(cos_op_outputs[0], np.cos(3.0 * 3.0))
|
||||
|
||||
|
@ -582,9 +582,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
});
|
||||
m.def("TFE_Py_RecordGradient",
|
||||
[](const py::handle& op_name, const py::handle& inputs,
|
||||
const py::handle& attrs, const py::handle& results) {
|
||||
const py::handle& attrs, const py::handle& results,
|
||||
const py::handle& forward_pass_name_scope) {
|
||||
return tensorflow::pyo_or_throw(TFE_Py_RecordGradient(
|
||||
op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr()));
|
||||
op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
|
||||
forward_pass_name_scope.ptr()));
|
||||
});
|
||||
m.def("TFE_Py_UID", []() { return tensorflow::pyo_or_throw(TFE_Py_UID()); });
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user