Do not differentiate integers in the eager backprop API.

(with bugfix)

PiperOrigin-RevId: 196184587
Alexandre Passos authored on 2018-05-10 15:54:13 -07:00; committed by TensorFlower Gardener
parent f7e24ab111
commit 8a8dddf8bd
7 changed files with 121 additions and 19 deletions

View File

@@ -130,13 +130,15 @@ class GradientTape {
     }
   }
-  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
+  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
+                    gtl::ArraySlice<tensorflow::DataType> dtypes);
   void Watch(int64 tensor_id);
   void RecordOperation(const string& op_type,
                        gtl::ArraySlice<TapeTensor> output_tensors,
                        gtl::ArraySlice<int64> input_tensor_id,
+                       gtl::ArraySlice<tensorflow::DataType> input_dtypes,
                        BackwardFunction* backward_function,
                        const std::function<void()>& backward_function_deleter);

@@ -170,12 +172,32 @@ class GradientTape {
 // Template instantiations here

+inline bool IsDtypeTrainable(DataType dtype) {
+  switch (dtype) {
+    case DT_HALF:
+    case DT_BFLOAT16:
+    case DT_FLOAT:
+    case DT_DOUBLE:
+    case DT_COMPLEX64:
+    case DT_COMPLEX128:
+    case DT_RESOURCE:
+    case DT_VARIANT:
+      return true;
+    default:
+      return false;
+  }
+}
+
 template <typename Gradient, typename BackwardFunction>
 bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
-    gtl::ArraySlice<int64> tensor_ids) {
-  for (int64 i : tensor_ids) {
-    if (tensor_tape_.find(i) != tensor_tape_.end()) {
-      return true;
+    gtl::ArraySlice<int64> tensor_ids,
+    gtl::ArraySlice<tensorflow::DataType> dtypes) {
+  CHECK_EQ(tensor_ids.size(), dtypes.size());
+  for (int i = 0; i < tensor_ids.size(); ++i) {
+    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
+      if (IsDtypeTrainable(dtypes[i])) {
+        return true;
+      }
     }
   }
   return false;

@@ -189,9 +211,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
 template <typename Gradient, typename BackwardFunction>
 void GradientTape<Gradient, BackwardFunction>::RecordOperation(
     const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
-    gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+    gtl::ArraySlice<int64> input_tensor_id,
+    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+    BackwardFunction* backward_function,
     const std::function<void()>& backward_function_deleter) {
-  if (!ShouldRecord(input_tensor_id)) {
+  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
     backward_function_deleter();
     return;
   }
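The new `dtypes` argument gates recording on dtype as well as on watched ids. Below is a minimal Python sketch of the same check, with a plain set standing in for `tensor_tape_`; all names here are hypothetical illustrations, not TensorFlow API:

```python
# Dtypes IsDtypeTrainable accepts; integer dtypes are deliberately absent.
TRAINABLE_DTYPES = {
    "half", "bfloat16", "float32", "float64",
    "complex64", "complex128", "resource", "variant",
}

def should_record(watched_ids, tensor_ids, dtypes):
  # An op is recorded only if some input is watched AND has a trainable
  # dtype; purely-integer inputs never trigger recording.
  assert len(tensor_ids) == len(dtypes)
  return any(tid in watched_ids and dt in TRAINABLE_DTYPES
             for tid, dt in zip(tensor_ids, dtypes))
```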

View File

@@ -57,7 +57,7 @@ class TFETest(test_util.TensorFlowTestCase):
       return math_ops.multiply(x, x)

     grad = tfe.gradients_function(square)
-    self.assertEquals([6], [x.numpy() for x in grad(3)])
+    self.assertEquals([6], [x.numpy() for x in grad(3.)])

   def testGradOfGrad(self):
@@ -66,7 +66,7 @@ class TFETest(test_util.TensorFlowTestCase):
     grad = tfe.gradients_function(square)
     gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
-    self.assertEquals([2], [x.numpy() for x in gradgrad(3)])
+    self.assertEquals([2], [x.numpy() for x in gradgrad(3.)])

   def testCustomGrad(self):
@@ -80,7 +80,7 @@ class TFETest(test_util.TensorFlowTestCase):
       return y, grad_fn

     grad = tfe.gradients_function(f)
-    self.assertEquals([12], [x.numpy() for x in grad(3)])
+    self.assertEquals([12], [x.numpy() for x in grad(3.)])

   def testGPU(self):
     if tfe.num_gpus() <= 0:
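These assertions previously passed the Python int `3`; after this change that argument is converted to an `int32` tensor whose gradient is `None`, so calling `.numpy()` on the result would fail. A hedged illustration, assuming the `tfe` import this test file uses and eager execution enabled:

```python
import tensorflow as tf
from tensorflow.contrib.eager.python import tfe  # import path assumed

tf.enable_eager_execution()

grad = tfe.gradients_function(lambda x: x * x)
print([g.numpy() for g in grad(3.)])  # [6.0]: float arguments differentiate
print(grad(3))                        # [None]: 3 becomes an int32 tensor,
                                      # which the tape now skips
```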

View File

@@ -358,6 +358,8 @@ def gradients_function(f, params=None):
   assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
   ```

+  Note that only tensors with real or complex dtypes are differentiable.
+
   Args:
     f: function to be differentiated. If `f` returns a scalar, this scalar will
       be differentiated. If `f` returns a tensor or list of tensors, by default
@@ -700,6 +702,9 @@ class GradientTape(object):
   dz_dx = g.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
   dy_dx = g.gradient(y, x)  # 6.0
   del g  # Drop the reference to the tape
   ```
+
+  Note that only tensors with real or complex dtypes are differentiable.
   """

   def __init__(self, persistent=False):
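The documented behavior can be seen directly with `gradients_function`; this mirrors the `testGradientInteger` case added in the test file below (run under eager execution):

```python
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op

def f(x):
  return x + x

# An integer tensor yields a None gradient rather than an error.
print(backprop.gradients_function(f)(constant_op.constant(1))[0])   # None
# The same function differentiates normally for a float tensor.
print(backprop.gradients_function(f)(constant_op.constant(1.))[0])  # 2.0
```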

View File

@@ -96,6 +96,18 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(grads_and_vars[0][0], 1.0)
     self.assertAllEqual(id(grads_and_vars[0][1]), id(x))

+  def testWhereGradient(self):
+    # Note: where is special because only some of its arguments are of
+    # differentiable dtypes.
+
+    def f(x):
+      return array_ops.where(x < 10, x, x * x)
+
+    g = backprop.gradients_function(f)
+
+    self.assertAllEqual(g(5.)[0], 1.0)
+    self.assertAllEqual(g(50.)[0], 100.0)
+
   def testTwoTargets(self):
     with backprop.GradientTape() as t:
       x = constant_op.constant(3.0)
@@ -124,6 +136,14 @@ class BackpropTest(test.TestCase):
     grad_fn = backprop.gradients_function(f)
     self.assertAllEqual(2., grad_fn(1., dy=2.)[0])

+  def testGradientInteger(self):
+
+    def f(x):
+      return x + x
+
+    int_tensor = constant_op.constant(1)
+    self.assertEqual(backprop.gradients_function(f)(int_tensor)[0], None)
+
   def testErrors(self):

     @custom_gradient.custom_gradient
@@ -753,7 +773,7 @@ class BackpropTest(test.TestCase):
       return result, grad

     x = resource_variable_ops.ResourceVariable(
-        initial_value=3, name='X.' + self.id())
+        initial_value=3., name='X.' + self.id())

     def f():
       return my_square(x)
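The `testWhereGradient` case above pins down why the dtype check must be per input rather than per op: `where`'s boolean condition is never differentiable, while its float arguments still are. The same computation with the derivative spelled out (assumes eager execution):

```python
from tensorflow.python.eager import backprop
from tensorflow.python.ops import array_ops

def f(x):
  # Only x is differentiable here; the bool condition x < 10 is not.
  return array_ops.where(x < 10, x, x * x)

g = backprop.gradients_function(f)
print(g(5.)[0])   # 1.0: the identity branch, df/dx = 1
print(g(50.)[0])  # 100.0: the x*x branch, df/dx = 2x = 100 at x = 50
```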

View File

@@ -650,6 +650,12 @@ tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
   return reinterpret_cast<const EagerTensor*>(tensor)->id;
 }

+tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
+  CHECK(EagerTensor_CheckExact(tensor));
+  return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
+      reinterpret_cast<const EagerTensor*>(tensor)->handle));
+}
+
 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
   if (!PyType_Check(base_class)) {
     PyErr_SetString(

View File

@@ -22,6 +22,7 @@ limitations under the License.
 bool EagerTensor_CheckExact(const PyObject* o);
 tensorflow::int64 EagerTensor_id(const PyObject* tensor);
+tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);

 namespace tensorflow {
 TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);

View File

@@ -843,6 +843,24 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
   return id;
 }

+static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
+  if (EagerTensor_CheckExact(tensor)) {
+    return EagerTensor_dtype(tensor);
+  }
+  PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
+  if (dtype_field == nullptr) {
+    return tensorflow::DT_INVALID;
+  }
+  PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
+  Py_DECREF(dtype_field);
+  if (enum_field == nullptr) {
+    return tensorflow::DT_INVALID;
+  }
+  tensorflow::int64 id = MakeInt(enum_field);
+  Py_DECREF(enum_field);
+  return static_cast<tensorflow::DataType>(id);
+}
+
 class GradientTape
     : public tensorflow::eager::GradientTape<PyObject, PyObject> {
  public:
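For non-`EagerTensor` inputs, `FastTensorDtype` simply walks `tensor.dtype._type_enum` (an internal `DType` field) and falls back to `DT_INVALID` when an attribute is missing. A rough Python equivalent of that fallback path, with a hypothetical helper name:

```python
DT_INVALID = 0  # numeric value of tensorflow::DT_INVALID

def fast_tensor_dtype(tensor):
  # Same attribute walk as the CPython lookups in FastTensorDtype above.
  dtype = getattr(tensor, "dtype", None)
  if dtype is None:
    return DT_INVALID
  return getattr(dtype, "_type_enum", DT_INVALID)
```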
@@ -1053,15 +1071,18 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
   // TODO(apassos) consider not building a list and changing the API to check
   // each tensor individually.
   std::vector<tensorflow::int64> tensor_ids;
+  std::vector<tensorflow::DataType> dtypes;
   tensor_ids.reserve(len);
+  dtypes.reserve(len);
   for (int i = 0; i < len; ++i) {
     PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
     tensor_ids.push_back(FastTensorId(item));
+    dtypes.push_back(FastTensorDtype(item));
   }
   Py_DECREF(seq);
   auto tape_set = *tape_set_ptr;
   for (TFE_Py_Tape* tape : tape_set) {
-    if (tape->tape->ShouldRecord(tensor_ids)) {
+    if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
       Py_RETURN_TRUE;
     }
   }
@@ -1169,9 +1190,27 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
 }

 namespace {

-void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
-                            const std::vector<tensorflow::int64>& input_ids,
-                            PyObject* backward_function) {
+std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
+  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+  if (seq == nullptr) {
+    return {};
+  }
+  int len = PySequence_Fast_GET_SIZE(seq);
+  std::vector<tensorflow::DataType> list;
+  list.reserve(len);
+  for (int i = 0; i < len; ++i) {
+    PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
+    list.push_back(FastTensorDtype(tensor));
+  }
+  Py_DECREF(seq);
+  return list;
+}
+
+void TapeSetRecordOperation(
+    PyObject* op_type, PyObject* output_tensors,
+    const std::vector<tensorflow::int64>& input_ids,
+    const std::vector<tensorflow::DataType>& input_dtypes,
+    PyObject* backward_function) {
   std::vector<tensorflow::eager::TapeTensor> output_info;
   PyObject* seq = PySequence_Fast(output_tensors,
                                   "expected a sequence of integer tensor ids");
@@ -1206,7 +1245,7 @@ void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
   for (TFE_Py_Tape* tape : SafeTapeSet()) {
     Py_INCREF(backward_function);
     tape->tape->RecordOperation(
-        op_type_str, output_info, input_ids, backward_function,
+        op_type_str, output_info, input_ids, input_dtypes, backward_function,
         [backward_function]() { Py_DECREF(backward_function); });
   }
 }
@@ -1221,7 +1260,11 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
   if (PyErr_Occurred()) return;

-  TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function);
+  std::vector<tensorflow::DataType> input_dtypes =
+      MakeTensorDtypeList(input_tensors);
+  if (PyErr_Occurred()) return;
+  TapeSetRecordOperation(op_type, output_tensors, input_ids, input_dtypes,
+                         backward_function);
 }

 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
@@ -1710,10 +1753,12 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
                          PyObject* results, PyObject* name) {
   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
   if (PyErr_Occurred()) return nullptr;
+  std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
+  if (PyErr_Occurred()) return nullptr;

   bool should_record = false;
   for (TFE_Py_Tape* tape : SafeTapeSet()) {
-    if (tape->tape->ShouldRecord(input_ids)) {
+    if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
       should_record = true;
       break;
     }
@@ -1744,7 +1789,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   Py_DECREF(callback_args);
   if (backward_function == nullptr) return nullptr;

-  TapeSetRecordOperation(op_name, results, input_ids, backward_function);
+  TapeSetRecordOperation(op_name, results, input_ids, input_dtypes,
+                         backward_function);
   Py_DECREF(backward_function);
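Taken together, every recording entry point now threads dtypes alongside tensor ids, so ops whose watched inputs are all integers are never taped and their backward functions are released immediately. A hedged Python pseudocode sketch of the resulting flow (hypothetical tape objects, not the real API):

```python
def record_gradient(op_name, inputs, results, backward_function, tapes):
  # MakeTensorIDList / MakeTensorDtypeList: ids and dtypes stay parallel.
  input_ids = [t.id for t in inputs]
  input_dtypes = [t.dtype for t in inputs]
  # If every tape declines (e.g. all watched inputs are integers), the
  # backward function is dropped and nothing is recorded.
  if not any(t.should_record(input_ids, input_dtypes) for t in tapes):
    return
  for t in tapes:
    t.record_operation(op_name, results, input_ids, input_dtypes,
                       backward_function)
```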