Do not differentiate integers in the eager backprop API.
(with bugfix) PiperOrigin-RevId: 196184587
This commit is contained in:
parent
f7e24ab111
commit
8a8dddf8bd
@ -130,13 +130,15 @@ class GradientTape {
|
||||
}
|
||||
}
|
||||
|
||||
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
|
||||
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes);
|
||||
|
||||
void Watch(int64 tensor_id);
|
||||
|
||||
void RecordOperation(const string& op_type,
|
||||
gtl::ArraySlice<TapeTensor> output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
|
||||
BackwardFunction* backward_function,
|
||||
const std::function<void()>& backward_function_deleter);
|
||||
|
||||
@ -170,12 +172,32 @@ class GradientTape {
|
||||
|
||||
// Template instantiations here
|
||||
|
||||
inline bool IsDtypeTrainable(DataType dtype) {
|
||||
switch (dtype) {
|
||||
case DT_HALF:
|
||||
case DT_BFLOAT16:
|
||||
case DT_FLOAT:
|
||||
case DT_DOUBLE:
|
||||
case DT_COMPLEX64:
|
||||
case DT_COMPLEX128:
|
||||
case DT_RESOURCE:
|
||||
case DT_VARIANT:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Gradient, typename BackwardFunction>
|
||||
bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
|
||||
gtl::ArraySlice<int64> tensor_ids) {
|
||||
for (int64 i : tensor_ids) {
|
||||
if (tensor_tape_.find(i) != tensor_tape_.end()) {
|
||||
return true;
|
||||
gtl::ArraySlice<int64> tensor_ids,
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes) {
|
||||
CHECK_EQ(tensor_ids.size(), dtypes.size());
|
||||
for (int i = 0; i < tensor_ids.size(); ++i) {
|
||||
if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
|
||||
if (IsDtypeTrainable(dtypes[i])) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
@ -189,9 +211,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
|
||||
template <typename Gradient, typename BackwardFunction>
|
||||
void GradientTape<Gradient, BackwardFunction>::RecordOperation(
|
||||
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
|
||||
BackwardFunction* backward_function,
|
||||
const std::function<void()>& backward_function_deleter) {
|
||||
if (!ShouldRecord(input_tensor_id)) {
|
||||
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
|
||||
backward_function_deleter();
|
||||
return;
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
return math_ops.multiply(x, x)
|
||||
|
||||
grad = tfe.gradients_function(square)
|
||||
self.assertEquals([6], [x.numpy() for x in grad(3)])
|
||||
self.assertEquals([6], [x.numpy() for x in grad(3.)])
|
||||
|
||||
def testGradOfGrad(self):
|
||||
|
||||
@ -66,7 +66,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
|
||||
grad = tfe.gradients_function(square)
|
||||
gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
|
||||
self.assertEquals([2], [x.numpy() for x in gradgrad(3)])
|
||||
self.assertEquals([2], [x.numpy() for x in gradgrad(3.)])
|
||||
|
||||
def testCustomGrad(self):
|
||||
|
||||
@ -80,7 +80,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
return y, grad_fn
|
||||
|
||||
grad = tfe.gradients_function(f)
|
||||
self.assertEquals([12], [x.numpy() for x in grad(3)])
|
||||
self.assertEquals([12], [x.numpy() for x in grad(3.)])
|
||||
|
||||
def testGPU(self):
|
||||
if tfe.num_gpus() <= 0:
|
||||
|
@ -358,6 +358,8 @@ def gradients_function(f, params=None):
|
||||
assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
|
||||
```
|
||||
|
||||
Note that only tensors with real or complex dtypes are differentiable.
|
||||
|
||||
Args:
|
||||
f: function to be differentiated. If `f` returns a scalar, this scalar will
|
||||
be differentiated. If `f` returns a tensor or list of tensors, by default
|
||||
@ -700,6 +702,9 @@ class GradientTape(object):
|
||||
dz_dx = g.gradient(z, x) # 108.0 (4*x^3 at x = 3)
|
||||
dy_dx = g.gradient(y, x) # 6.0
|
||||
del g # Drop the reference to the tape
|
||||
```
|
||||
|
||||
Note that only tensors with real or complex dtypes are differentiable.
|
||||
"""
|
||||
|
||||
def __init__(self, persistent=False):
|
||||
|
@ -96,6 +96,18 @@ class BackpropTest(test.TestCase):
|
||||
self.assertAllEqual(grads_and_vars[0][0], 1.0)
|
||||
self.assertAllEqual(id(grads_and_vars[0][1]), id(x))
|
||||
|
||||
def testWhereGradient(self):
|
||||
# Note: where is special because only some of its arguments are of
|
||||
# differentiable dtypes.
|
||||
|
||||
def f(x):
|
||||
return array_ops.where(x < 10, x, x * x)
|
||||
|
||||
g = backprop.gradients_function(f)
|
||||
|
||||
self.assertAllEqual(g(5.)[0], 1.0)
|
||||
self.assertAllEqual(g(50.)[0], 100.0)
|
||||
|
||||
def testTwoTargets(self):
|
||||
with backprop.GradientTape() as t:
|
||||
x = constant_op.constant(3.0)
|
||||
@ -124,6 +136,14 @@ class BackpropTest(test.TestCase):
|
||||
grad_fn = backprop.gradients_function(f)
|
||||
self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
|
||||
|
||||
def testGradientInteger(self):
|
||||
|
||||
def f(x):
|
||||
return x + x
|
||||
|
||||
int_tensor = constant_op.constant(1)
|
||||
self.assertEqual(backprop.gradients_function(f)(int_tensor)[0], None)
|
||||
|
||||
def testErrors(self):
|
||||
|
||||
@custom_gradient.custom_gradient
|
||||
@ -753,7 +773,7 @@ class BackpropTest(test.TestCase):
|
||||
return result, grad
|
||||
|
||||
x = resource_variable_ops.ResourceVariable(
|
||||
initial_value=3, name='X.' + self.id())
|
||||
initial_value=3., name='X.' + self.id())
|
||||
|
||||
def f():
|
||||
return my_square(x)
|
||||
|
@ -650,6 +650,12 @@ tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
|
||||
return reinterpret_cast<const EagerTensor*>(tensor)->id;
|
||||
}
|
||||
|
||||
tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
|
||||
CHECK(EagerTensor_CheckExact(tensor));
|
||||
return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
|
||||
reinterpret_cast<const EagerTensor*>(tensor)->handle));
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
|
||||
if (!PyType_Check(base_class)) {
|
||||
PyErr_SetString(
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
|
||||
bool EagerTensor_CheckExact(const PyObject* o);
|
||||
tensorflow::int64 EagerTensor_id(const PyObject* tensor);
|
||||
tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
|
||||
|
||||
namespace tensorflow {
|
||||
TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
|
||||
|
@ -843,6 +843,24 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
|
||||
return id;
|
||||
}
|
||||
|
||||
static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
|
||||
if (EagerTensor_CheckExact(tensor)) {
|
||||
return EagerTensor_dtype(tensor);
|
||||
}
|
||||
PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
|
||||
if (dtype_field == nullptr) {
|
||||
return tensorflow::DT_INVALID;
|
||||
}
|
||||
PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
|
||||
Py_DECREF(dtype_field);
|
||||
if (dtype_field == nullptr) {
|
||||
return tensorflow::DT_INVALID;
|
||||
}
|
||||
tensorflow::int64 id = MakeInt(enum_field);
|
||||
Py_DECREF(enum_field);
|
||||
return static_cast<tensorflow::DataType>(id);
|
||||
}
|
||||
|
||||
class GradientTape
|
||||
: public tensorflow::eager::GradientTape<PyObject, PyObject> {
|
||||
public:
|
||||
@ -1053,15 +1071,18 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
|
||||
// TODO(apassos) consider not building a list and changing the API to check
|
||||
// each tensor individually.
|
||||
std::vector<tensorflow::int64> tensor_ids;
|
||||
std::vector<tensorflow::DataType> dtypes;
|
||||
tensor_ids.reserve(len);
|
||||
dtypes.reserve(len);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
|
||||
tensor_ids.push_back(FastTensorId(item));
|
||||
dtypes.push_back(FastTensorDtype(item));
|
||||
}
|
||||
Py_DECREF(seq);
|
||||
auto tape_set = *tape_set_ptr;
|
||||
for (TFE_Py_Tape* tape : tape_set) {
|
||||
if (tape->tape->ShouldRecord(tensor_ids)) {
|
||||
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
}
|
||||
@ -1169,9 +1190,27 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
|
||||
const std::vector<tensorflow::int64>& input_ids,
|
||||
PyObject* backward_function) {
|
||||
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
|
||||
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
|
||||
if (seq == nullptr) {
|
||||
return {};
|
||||
}
|
||||
int len = PySequence_Fast_GET_SIZE(seq);
|
||||
std::vector<tensorflow::DataType> list;
|
||||
list.reserve(len);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
|
||||
list.push_back(FastTensorDtype(tensor));
|
||||
}
|
||||
Py_DECREF(seq);
|
||||
return list;
|
||||
}
|
||||
|
||||
void TapeSetRecordOperation(
|
||||
PyObject* op_type, PyObject* output_tensors,
|
||||
const std::vector<tensorflow::int64>& input_ids,
|
||||
const std::vector<tensorflow::DataType>& input_dtypes,
|
||||
PyObject* backward_function) {
|
||||
std::vector<tensorflow::eager::TapeTensor> output_info;
|
||||
PyObject* seq = PySequence_Fast(output_tensors,
|
||||
"expected a sequence of integer tensor ids");
|
||||
@ -1206,7 +1245,7 @@ void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
|
||||
for (TFE_Py_Tape* tape : SafeTapeSet()) {
|
||||
Py_INCREF(backward_function);
|
||||
tape->tape->RecordOperation(
|
||||
op_type_str, output_info, input_ids, backward_function,
|
||||
op_type_str, output_info, input_ids, input_dtypes, backward_function,
|
||||
[backward_function]() { Py_DECREF(backward_function); });
|
||||
}
|
||||
}
|
||||
@ -1221,7 +1260,11 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
|
||||
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
|
||||
if (PyErr_Occurred()) return;
|
||||
|
||||
TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function);
|
||||
std::vector<tensorflow::DataType> input_dtypes =
|
||||
MakeTensorDtypeList(input_tensors);
|
||||
if (PyErr_Occurred()) return;
|
||||
TapeSetRecordOperation(op_type, output_tensors, input_ids, input_dtypes,
|
||||
backward_function);
|
||||
}
|
||||
|
||||
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
|
||||
@ -1710,10 +1753,12 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
PyObject* results, PyObject* name) {
|
||||
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
|
||||
bool should_record = false;
|
||||
for (TFE_Py_Tape* tape : SafeTapeSet()) {
|
||||
if (tape->tape->ShouldRecord(input_ids)) {
|
||||
if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
|
||||
should_record = true;
|
||||
break;
|
||||
}
|
||||
@ -1744,7 +1789,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
Py_DECREF(callback_args);
|
||||
if (backward_function == nullptr) return nullptr;
|
||||
|
||||
TapeSetRecordOperation(op_name, results, input_ids, backward_function);
|
||||
TapeSetRecordOperation(op_name, results, input_ids, input_dtypes,
|
||||
backward_function);
|
||||
|
||||
Py_DECREF(backward_function);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user