Do not differentiate integers in the eager backprop API.
(with bugfix)

PiperOrigin-RevId: 196184587
commit 8a8dddf8bd
parent f7e24ab111
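In behavioral terms: after this change the tape only records operations whose inputs have a trainable dtype, so asking for the gradient with respect to an integer tensor yields `None` instead of a bogus gradient. A minimal sketch of that behavior, mirroring the `testGradientInteger` test added below; it assumes the TF 1.8-era `tf.enable_eager_execution()` and `tensorflow.contrib.eager` API that the tests in this commit use:

```python
# Sketch only; follows the tests added in this commit.
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def f(x):
  return x + x

# Float inputs are differentiated as before.
print(tfe.gradients_function(f)(3.)[0].numpy())      # => 6.0

# Integer inputs are no longer recorded on the tape: gradient is None.
print(tfe.gradients_function(f)(tf.constant(1))[0])  # => None
```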
@@ -130,13 +130,15 @@ class GradientTape {
     }
   }
 
-  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
+  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
+                    gtl::ArraySlice<tensorflow::DataType> dtypes);
 
   void Watch(int64 tensor_id);
 
   void RecordOperation(const string& op_type,
                        gtl::ArraySlice<TapeTensor> output_tensors,
                        gtl::ArraySlice<int64> input_tensor_id,
+                       gtl::ArraySlice<tensorflow::DataType> input_dtypes,
                        BackwardFunction* backward_function,
                        const std::function<void()>& backward_function_deleter);
@@ -170,14 +172,34 @@ class GradientTape {
 
 // Template instantiations here
 
+inline bool IsDtypeTrainable(DataType dtype) {
+  switch (dtype) {
+    case DT_HALF:
+    case DT_BFLOAT16:
+    case DT_FLOAT:
+    case DT_DOUBLE:
+    case DT_COMPLEX64:
+    case DT_COMPLEX128:
+    case DT_RESOURCE:
+    case DT_VARIANT:
+      return true;
+    default:
+      return false;
+  }
+}
+
 template <typename Gradient, typename BackwardFunction>
 bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
-    gtl::ArraySlice<int64> tensor_ids) {
-  for (int64 i : tensor_ids) {
-    if (tensor_tape_.find(i) != tensor_tape_.end()) {
-      return true;
+    gtl::ArraySlice<int64> tensor_ids,
+    gtl::ArraySlice<tensorflow::DataType> dtypes) {
+  CHECK_EQ(tensor_ids.size(), dtypes.size());
+  for (int i = 0; i < tensor_ids.size(); ++i) {
+    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
+      if (IsDtypeTrainable(dtypes[i])) {
+        return true;
+      }
     }
   }
   return false;
 }
 
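For reference, a hypothetical Python rendering of the `IsDtypeTrainable` whitelist above (not part of the commit); the `dtypes.*` names are the standard Python equivalents of the C++ `DT_*` enum values:

```python
# Hypothetical Python mirror of the C++ IsDtypeTrainable whitelist.
from tensorflow.python.framework import dtypes

_TRAINABLE_DTYPES = frozenset([
    dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
    dtypes.complex64, dtypes.complex128, dtypes.resource, dtypes.variant,
])

def is_dtype_trainable(dtype):
  """Returns True if the tape should record operations on this dtype."""
  return dtype in _TRAINABLE_DTYPES
```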
@@ -189,9 +211,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
 template <typename Gradient, typename BackwardFunction>
 void GradientTape<Gradient, BackwardFunction>::RecordOperation(
     const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
-    gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+    gtl::ArraySlice<int64> input_tensor_id,
+    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+    BackwardFunction* backward_function,
     const std::function<void()>& backward_function_deleter) {
-  if (!ShouldRecord(input_tensor_id)) {
+  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    backward_function_deleter();
    return;
  }
@@ -57,7 +57,7 @@ class TFETest(test_util.TensorFlowTestCase):
       return math_ops.multiply(x, x)
 
     grad = tfe.gradients_function(square)
-    self.assertEquals([6], [x.numpy() for x in grad(3)])
+    self.assertEquals([6], [x.numpy() for x in grad(3.)])
 
   def testGradOfGrad(self):
 
@@ -66,7 +66,7 @@ class TFETest(test_util.TensorFlowTestCase):
 
     grad = tfe.gradients_function(square)
     gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
-    self.assertEquals([2], [x.numpy() for x in gradgrad(3)])
+    self.assertEquals([2], [x.numpy() for x in gradgrad(3.)])
 
   def testCustomGrad(self):
 
@@ -80,7 +80,7 @@ class TFETest(test_util.TensorFlowTestCase):
       return y, grad_fn
 
     grad = tfe.gradients_function(f)
-    self.assertEquals([12], [x.numpy() for x in grad(3)])
+    self.assertEquals([12], [x.numpy() for x in grad(3.)])
 
   def testGPU(self):
     if tfe.num_gpus() <= 0:
@@ -358,6 +358,8 @@ def gradients_function(f, params=None):
   assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
   ```
 
+  Note that only tensors with real or complex dtypes are differentiable.
+
   Args:
     f: function to be differentiated. If `f` returns a scalar, this scalar will
       be differentiated. If `f` returns a tensor or list of tensors, by default
@@ -700,6 +702,9 @@ class GradientTape(object):
   dz_dx = g.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
   dy_dx = g.gradient(y, x)  # 6.0
   del g  # Drop the reference to the tape
+  ```
+
+  Note that only tensors with real or complex dtypes are differentiable.
   """
 
   def __init__(self, persistent=False):
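The docstring note has a concrete consequence worth spelling out: even an explicitly watched integer tensor produces a `None` gradient, because `ShouldRecord` now rejects its dtype. A hedged sketch (hypothetical snippet, not from the commit, assuming the contrib-era `tfe.GradientTape`):

```python
# Sketch: watching a non-trainable dtype is a silent no-op for gradients.
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

x = tf.constant(3)           # int32: not a trainable dtype
with tfe.GradientTape() as g:
  g.watch(x)                 # watched, but ops on x are not recorded
  y = x * x
print(g.gradient(y, [x]))    # => [None]
```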
@@ -96,6 +96,18 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(grads_and_vars[0][0], 1.0)
     self.assertAllEqual(id(grads_and_vars[0][1]), id(x))
 
+  def testWhereGradient(self):
+    # Note: where is special because only some of its arguments are of
+    # differentiable dtypes.
+
+    def f(x):
+      return array_ops.where(x < 10, x, x * x)
+
+    g = backprop.gradients_function(f)
+
+    self.assertAllEqual(g(5.)[0], 1.0)
+    self.assertAllEqual(g(50.)[0], 100.0)
+
   def testTwoTargets(self):
     with backprop.GradientTape() as t:
       x = constant_op.constant(3.0)
@@ -124,6 +136,14 @@ class BackpropTest(test.TestCase):
     grad_fn = backprop.gradients_function(f)
     self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
 
+  def testGradientInteger(self):
+
+    def f(x):
+      return x + x
+
+    int_tensor = constant_op.constant(1)
+    self.assertEqual(backprop.gradients_function(f)(int_tensor)[0], None)
+
   def testErrors(self):
 
     @custom_gradient.custom_gradient
@@ -753,7 +773,7 @@ class BackpropTest(test.TestCase):
       return result, grad
 
     x = resource_variable_ops.ResourceVariable(
-        initial_value=3, name='X.' + self.id())
+        initial_value=3., name='X.' + self.id())
 
     def f():
       return my_square(x)
@@ -650,6 +650,12 @@ tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
   return reinterpret_cast<const EagerTensor*>(tensor)->id;
 }
 
+tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
+  CHECK(EagerTensor_CheckExact(tensor));
+  return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
+      reinterpret_cast<const EagerTensor*>(tensor)->handle));
+}
+
 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
   if (!PyType_Check(base_class)) {
     PyErr_SetString(
@@ -22,6 +22,7 @@ limitations under the License.
 
 bool EagerTensor_CheckExact(const PyObject* o);
 tensorflow::int64 EagerTensor_id(const PyObject* tensor);
+tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
 
 namespace tensorflow {
 TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
@@ -843,6 +843,24 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
   return id;
 }
 
+static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
+  if (EagerTensor_CheckExact(tensor)) {
+    return EagerTensor_dtype(tensor);
+  }
+  PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
+  if (dtype_field == nullptr) {
+    return tensorflow::DT_INVALID;
+  }
+  PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
+  Py_DECREF(dtype_field);
+  if (enum_field == nullptr) {
+    return tensorflow::DT_INVALID;
+  }
+  tensorflow::int64 id = MakeInt(enum_field);
+  Py_DECREF(enum_field);
+  return static_cast<tensorflow::DataType>(id);
+}
+
 class GradientTape
     : public tensorflow::eager::GradientTape<PyObject, PyObject> {
  public:
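`FastTensorDtype` takes the fast path for exact `EagerTensor` objects and otherwise falls back to reading the Python-level `dtype._type_enum` attribute. A hypothetical pure-Python equivalent of that fallback, to make the lookup explicit (the function name here is illustrative only):

```python
# Hypothetical Python equivalent of FastTensorDtype's fallback path above.
def fast_tensor_dtype_enum(tensor):
  """Returns the DataType enum value for `tensor`, or 0 (DT_INVALID)."""
  dtype = getattr(tensor, "dtype", None)
  if dtype is None:
    return 0  # tensorflow::DT_INVALID
  return getattr(dtype, "_type_enum", 0)
```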
@@ -1053,15 +1071,18 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
   // TODO(apassos) consider not building a list and changing the API to check
   // each tensor individually.
   std::vector<tensorflow::int64> tensor_ids;
+  std::vector<tensorflow::DataType> dtypes;
   tensor_ids.reserve(len);
+  dtypes.reserve(len);
   for (int i = 0; i < len; ++i) {
     PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
     tensor_ids.push_back(FastTensorId(item));
+    dtypes.push_back(FastTensorDtype(item));
   }
   Py_DECREF(seq);
   auto tape_set = *tape_set_ptr;
   for (TFE_Py_Tape* tape : tape_set) {
-    if (tape->tape->ShouldRecord(tensor_ids)) {
+    if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
       Py_RETURN_TRUE;
     }
   }
@@ -1169,8 +1190,26 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
 }
 
 namespace {
-void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
+std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
+  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+  if (seq == nullptr) {
+    return {};
+  }
+  int len = PySequence_Fast_GET_SIZE(seq);
+  std::vector<tensorflow::DataType> list;
+  list.reserve(len);
+  for (int i = 0; i < len; ++i) {
+    PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
+    list.push_back(FastTensorDtype(tensor));
+  }
+  Py_DECREF(seq);
+  return list;
+}
+
+void TapeSetRecordOperation(
+    PyObject* op_type, PyObject* output_tensors,
     const std::vector<tensorflow::int64>& input_ids,
+    const std::vector<tensorflow::DataType>& input_dtypes,
     PyObject* backward_function) {
   std::vector<tensorflow::eager::TapeTensor> output_info;
   PyObject* seq = PySequence_Fast(output_tensors,
@@ -1206,7 +1245,7 @@ void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
   for (TFE_Py_Tape* tape : SafeTapeSet()) {
     Py_INCREF(backward_function);
     tape->tape->RecordOperation(
-        op_type_str, output_info, input_ids, backward_function,
+        op_type_str, output_info, input_ids, input_dtypes, backward_function,
         [backward_function]() { Py_DECREF(backward_function); });
   }
 }
@@ -1221,7 +1260,11 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
   if (PyErr_Occurred()) return;
 
-  TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function);
+  std::vector<tensorflow::DataType> input_dtypes =
+      MakeTensorDtypeList(input_tensors);
+  if (PyErr_Occurred()) return;
+  TapeSetRecordOperation(op_type, output_tensors, input_ids, input_dtypes,
+                         backward_function);
 }
 
 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
@@ -1710,10 +1753,12 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
                          PyObject* results, PyObject* name) {
   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
   if (PyErr_Occurred()) return nullptr;
+  std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
+  if (PyErr_Occurred()) return nullptr;
 
   bool should_record = false;
   for (TFE_Py_Tape* tape : SafeTapeSet()) {
-    if (tape->tape->ShouldRecord(input_ids)) {
+    if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
       should_record = true;
       break;
     }
@@ -1744,7 +1789,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   Py_DECREF(callback_args);
   if (backward_function == nullptr) return nullptr;
 
-  TapeSetRecordOperation(op_name, results, input_ids, backward_function);
+  TapeSetRecordOperation(op_name, results, input_ids, input_dtypes,
+                         backward_function);
 
   Py_DECREF(backward_function);
 