Fix unstable test case for Select op (#15807)
* Revert "Revert "C++ gradient for Select (#14862)" (#15764)"
This reverts commit 4c19f77d2a
.
This commit is contained in:
parent
38976778c5
commit
6db014b448
@ -763,6 +763,24 @@ Status LgammaGrad(const Scope& scope, const Operation& op,
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);
|
||||
|
||||
Status SelectGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto comparator = op.input(0);
|
||||
auto x = op.input(1);
|
||||
auto zeros = ZerosLike(scope, x);
|
||||
auto grad = grad_inputs[0];
|
||||
|
||||
auto gx_1 = Where3(scope, comparator, grad, zeros);
|
||||
auto gx_2 = Where3(scope, comparator, zeros, grad);
|
||||
|
||||
grad_outputs->push_back(NoGradient());
|
||||
grad_outputs->push_back(gx_1);
|
||||
grad_outputs->push_back(gx_2);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Select", SelectGrad);
|
||||
|
||||
Status MinOrMaxGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
|
@ -882,5 +882,17 @@ TEST_F(NaryGradTest, Prod) {
|
||||
RunTest({x}, {x_shape}, {y}, {y_shape});
|
||||
}
|
||||
|
||||
TEST_F(NaryGradTest, Select) {
|
||||
TensorShape shape({3, 2});
|
||||
auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
// Use constant values to avoid instability when computing
|
||||
Tensor c =
|
||||
test::AsTensor<float>({-3.5f, 1.5f, -1.2f, 3.0f, -2.5f, 2.8f}, {3, 2});
|
||||
auto zero = Cast(scope_, Const(scope_, 0.0), c.dtype());
|
||||
auto y = Where3(scope_, Greater(scope_, c, zero), x1, x2);
|
||||
RunTest({x1, x2}, {shape, shape}, {y}, {shape});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user