Automated g4 rollback of changelist 279842730.
PiperOrigin-RevId: 281314244 Change-Id: I147be72280cc2f8b114437ec6917777a9739111d
This commit is contained in:
parent
30240224dc
commit
af594a2116
@ -116,14 +116,13 @@ class _MockOp(object):
|
||||
)
|
||||
|
||||
|
||||
def _gradient_function(op_name, attr_tuple, device, num_inputs, inputs, outputs,
|
||||
def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
|
||||
out_grads, skip_input_indices):
|
||||
"""Calls the gradient function of the op.
|
||||
|
||||
Args:
|
||||
op_name: the name of the op to be differentiated.
|
||||
attr_tuple: the attrs, as a tuple.
|
||||
device: the device of the op.
|
||||
num_inputs: the number of inputs to the op.
|
||||
inputs: inputs to the original operation.
|
||||
outputs: outputs to the original operation.
|
||||
@ -139,8 +138,7 @@ def _gradient_function(op_name, attr_tuple, device, num_inputs, inputs, outputs,
|
||||
if grad_fn is None:
|
||||
return [None] * num_inputs
|
||||
|
||||
with ops.device(device):
|
||||
return grad_fn(mock_op, *out_grads)
|
||||
return grad_fn(mock_op, *out_grads)
|
||||
|
||||
|
||||
pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function)
|
||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import remote
|
||||
@ -28,7 +27,6 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
class SoftDevicePlacementTest(test.TestCase):
|
||||
@ -88,23 +86,6 @@ class SoftDevicePlacementTest(test.TestCase):
|
||||
# We don't support nested device placement right now.
|
||||
self.assertIn('GPU:0', c.device)
|
||||
|
||||
@test_util.run_gpu_only
|
||||
def testGradientPlacement(self):
|
||||
with ops.device('GPU:0'):
|
||||
x = variables.Variable(1.0)
|
||||
with ops.device('CPU:0'):
|
||||
y = variables.Variable(1.0)
|
||||
|
||||
with backprop.GradientTape() as tape:
|
||||
with ops.device('GPU:0'):
|
||||
x1 = constant_op.constant(2.0) * x
|
||||
with ops.device('CPU:0'):
|
||||
y1 = constant_op.constant(2.0) * y
|
||||
z = x1 + y1
|
||||
grads = tape.gradient(z, [x, y])
|
||||
self.assertIn('GPU:0', grads[0].device)
|
||||
self.assertIn('CPU:0', grads[1].device)
|
||||
|
||||
|
||||
class ClusterPlacementTest(test.TestCase):
|
||||
|
||||
|
@ -3007,22 +3007,6 @@ PyObject* CopySequenceSettingIndicesToNull(
|
||||
return result;
|
||||
}
|
||||
|
||||
PyObject* DeviceFromTensorSeq(PyObject* seq) {
|
||||
for (Py_ssize_t i = 0; i < PySequence_Size(seq); i++) {
|
||||
PyObject* item = PySequence_ITEM(seq, i);
|
||||
PyObject* dev = PyObject_GetAttrString(item, "device");
|
||||
Py_DECREF(item);
|
||||
if (dev) {
|
||||
const char* devStr = TFE_GetPythonString(dev);
|
||||
if (devStr && !string(devStr).empty()) {
|
||||
return dev;
|
||||
}
|
||||
Py_DECREF(dev);
|
||||
}
|
||||
}
|
||||
return Py_None;
|
||||
}
|
||||
|
||||
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
PyObject* results) {
|
||||
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
|
||||
@ -3049,11 +3033,6 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
|
||||
string c_op_name = TFE_GetPythonString(op_name);
|
||||
|
||||
PyObject* device = DeviceFromTensorSeq(results);
|
||||
if (device == Py_None) {
|
||||
device = DeviceFromTensorSeq(inputs);
|
||||
}
|
||||
|
||||
PyObject* op_outputs;
|
||||
bool op_outputs_tuple_created = false;
|
||||
std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
|
||||
@ -3112,15 +3091,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
|
||||
TapeSetRecordOperation(
|
||||
op_name, inputs, results, input_ids, input_dtypes,
|
||||
[op_name, attrs, device, num_inputs, op_inputs, op_outputs]() {
|
||||
[op_name, attrs, num_inputs, op_inputs, op_outputs]() {
|
||||
Py_INCREF(op_name);
|
||||
Py_INCREF(attrs);
|
||||
Py_INCREF(device);
|
||||
Py_INCREF(num_inputs);
|
||||
Py_INCREF(op_inputs);
|
||||
Py_INCREF(op_outputs);
|
||||
PyBackwardFunction* function = new PyBackwardFunction(
|
||||
[op_name, attrs, device, num_inputs, op_inputs, op_outputs](
|
||||
[op_name, attrs, num_inputs, op_inputs, op_outputs](
|
||||
PyObject* output_grads,
|
||||
const std::vector<tensorflow::int64>& unneeded_gradients) {
|
||||
if (PyErr_Occurred()) {
|
||||
@ -3140,8 +3118,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
skip_input_indices.reset(Py_None);
|
||||
}
|
||||
tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
|
||||
"OOOOOOOO", op_name, attrs, device, num_inputs, op_inputs,
|
||||
op_outputs, output_grads, skip_input_indices.get()));
|
||||
"OOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
|
||||
output_grads, skip_input_indices.get()));
|
||||
|
||||
tensorflow::Safe_PyObjectPtr result(
|
||||
PyObject_CallObject(gradient_function, callback_args.get()));
|
||||
@ -3152,11 +3130,10 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
});
|
||||
return function;
|
||||
},
|
||||
[op_name, attrs, device, num_inputs, op_inputs,
|
||||
[op_name, attrs, num_inputs, op_inputs,
|
||||
op_outputs](PyBackwardFunction* backward_function) {
|
||||
Py_DECREF(op_name);
|
||||
Py_DECREF(attrs);
|
||||
Py_DECREF(device);
|
||||
Py_DECREF(num_inputs);
|
||||
Py_DECREF(op_inputs);
|
||||
Py_DECREF(op_outputs);
|
||||
@ -3166,7 +3143,6 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
forward_function);
|
||||
|
||||
Py_DECREF(num_inputs);
|
||||
Py_DECREF(device);
|
||||
if (op_outputs_tuple_created) Py_DECREF(op_outputs);
|
||||
if (op_inputs_tuple_created) Py_DECREF(op_inputs);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user