Make the C++ gradients implementation use zero gradients for bool tensor edges.

This fixes a problem where the generated gradients for Select cannot be executed: the incoming gradients are summed with an AddN op, which does not support bool data. (See the test case.) This code path should not be relevant for most new code, since the Python gradients code is preferred, but it can come up in older code that uses Defun.

PiperOrigin-RevId: 245967608
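For background on the failure the message describes, the snippet below is a minimal, hypothetical reproduction using the TensorFlow C++ client API; it is not part of this commit. AddN's "T" attribute does not allow DT_BOOL, so a graph that tries to sum bool tensors is rejected rather than executed (the exact error text and the point at which it surfaces are assumptions):

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Two bool tensors standing in for two gradients flowing into one edge.
  auto a = tensorflow::ops::Const(scope, {true, false});
  auto b = tensorflow::ops::Const(scope, {false, true});
  // AddN's type constraint excludes DT_BOOL, so this sum cannot be built/run.
  auto addn = tensorflow::ops::AddN(scope, {a, b});
  tensorflow::ClientSession session(scope);
  std::vector<tensorflow::Tensor> out;
  tensorflow::Status s = session.Run({addn.sum}, &out);
  LOG(INFO) << s.ToString();  // Expected to report an error, not a result.
  return 0;
}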
commit 593657998c
parent 7182f4ed6f

@@ -1640,9 +1640,14 @@ cc_library(
    ] + if_dynamic_kernels(
        [],
        otherwise = [
            "//tensorflow/core/kernels:aggregate_ops",
            "//tensorflow/core/kernels:bcast_ops",
            "//tensorflow/core/kernels:cast_op",
            "//tensorflow/core/kernels:constant_op",
            "//tensorflow/core/kernels:identity_op",
            "//tensorflow/core/kernels:random_ops",
            "//tensorflow/core/kernels:reduction_ops",
            "//tensorflow/core/kernels:reshape_op",
        ],
    ),
)

@@ -1439,6 +1439,58 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
  }
}

TEST_F(FunctionLibraryRuntimeTest, Gradient_Select) {
  FunctionDef my_select = FunctionDefHelper::Create(
      "MySelect",
      // Args
      {"condition: bool", "t: float32", "e: float32"},
      // Return values
      {"z: float32"},
      // Attrs
      {},
      // Nodes
      {
          {{"select0"}, "Select", {"condition", "t", "e"}, {{"T", DT_FLOAT}}},
          {{"select1"}, "Select", {"condition", "t", "e"}, {{"T", DT_FLOAT}}},
          {{"add"},
           "Add",
           {"select0:output", "select1:output"},
           {{"T", DT_FLOAT}}},
      },
      // Output mapping
      {{"z", "add:z"}});
  FunctionDef select_grad = FunctionDefHelper::Create(
      "MySelectGrad",
      // Args
      {"condition: bool", "t:float32", "e: float32", "dz: float32"},
      // Return values
      {"dt: float32"},
      // Attrs
      {},
      // Nodes
      {{
          {"grad"},
          "SymbolicGradient",
          {"condition", "t", "e", "dz"},
          {
              {"f", FunctionDefHelper::FunctionRef("MySelect")},
              {"Tin", DataTypeSlice({DT_BOOL, DT_FLOAT, DT_FLOAT, DT_FLOAT})},
              {"Tout", DataTypeSlice({DT_BOOL, DT_FLOAT, DT_FLOAT})},
          },
      }},
      // Output mapping
      {{"dt", "grad:output:1"}});
  Init({my_select, select_grad});

  auto condition = test::AsTensor<bool>({false});
  auto t = test::AsTensor<float>({13.0});
  auto e = test::AsTensor<float>({15.0});
  auto dz = test::AsTensor<float>({1.0});
  Tensor y;
  TF_EXPECT_OK(InstantiateAndRun(flr0_, "MySelectGrad", {},
                                 {condition, t, e, dz}, {&y}));
}

TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) {
  Init({});
  auto T = DT_FLOAT;

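Reading the new Gradient_Select test above: `condition` feeds both Select nodes, so the symbolic gradient builder accumulates two incoming gradients on that bool edge, which is exactly the AddN-over-bool case described in the commit message; with this change the edge simply receives zeros. A hypothetical extra assertion (not in the commit, which only verifies that the gradient function instantiates and runs) would also hold here, assuming the usual test::ExpectTensorEqual helper from tensor_testutil.h, since `condition` is false and only the `e` branch contributes to `z`:

  // Hypothetical follow-up check: with condition == {false}, both Selects
  // pick 'e', so the gradient flowing back to 't' is zero.
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({0.0}));
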
@@ -198,6 +198,9 @@ class SymbolicGradientBuilder {
  void BackpropAlongEdge(const NodeOut& dst_grad, const NodeOut& src);
  void BackpropZerosAlongEdge(const NodeOut& src);

  // Returns a node representing the sum of any backpropped gradients for 'src'.
  // This will be an AddN node if there is more than one accumulated gradient.
  // Returns zeros if there are no gradients, or the dtype is DT_BOOL.
  NodeOut SumGradients(const NodeOut& src);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);

@@ -296,7 +299,7 @@ NodeOut SymbolicGradientBuilder::SumGradients(const NodeOut& src) {
   auto iter = backprops_.find(src);
   CHECK(iter != backprops_.end());
   const auto& grads = iter->second;
-  if (grads.empty()) {
+  if (grads.empty() || dtype == DT_BOOL) {
     // Nothing propagated back. The best we can come up is zeros.
     Node* zero_like = AddZerosLike(graph_, src);
     return {zero_like, 0};