Fix unstable test case for Select op (#15807)

* Revert "Revert "C++ gradient for Select (#14862)" (#15764)"

This reverts commit 4c19f77d2a.
This commit is contained in:
Yan Facai (颜发才) 2018-01-04 10:13:18 +08:00 committed by Shanqing Cai
parent 38976778c5
commit 6db014b448
2 changed files with 30 additions and 0 deletions

View File

@ -763,6 +763,24 @@ Status LgammaGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);
Status SelectGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto comparator = op.input(0);
auto x = op.input(1);
auto zeros = ZerosLike(scope, x);
auto grad = grad_inputs[0];
auto gx_1 = Where3(scope, comparator, grad, zeros);
auto gx_2 = Where3(scope, comparator, zeros, grad);
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(gx_1);
grad_outputs->push_back(gx_2);
return scope.status();
}
REGISTER_GRADIENT_OP("Select", SelectGrad);
Status MinOrMaxGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {

View File

@ -882,5 +882,17 @@ TEST_F(NaryGradTest, Prod) {
RunTest({x}, {x_shape}, {y}, {y_shape});
}
TEST_F(NaryGradTest, Select) {
TensorShape shape({3, 2});
auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
// Use constant values to avoid instability when computing
Tensor c =
test::AsTensor<float>({-3.5f, 1.5f, -1.2f, 3.0f, -2.5f, 2.8f}, {3, 2});
auto zero = Cast(scope_, Const(scope_, 0.0), c.dtype());
auto y = Where3(scope_, Greater(scope_, c, zero), x1, x2);
RunTest({x1, x2}, {shape, shape}, {y}, {shape});
}
} // namespace
} // namespace tensorflow