diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 52c177212a8..afd92fbf488 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -763,6 +763,27 @@ Status LgammaGrad(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);
 
+// Gradient for the "Select" op (Where3): the boolean condition receives no
+// gradient; the incoming gradient is routed to input 1 (t) where the
+// condition is true and to input 2 (e) where it is false, with zeros elsewhere.
+Status SelectGrad(const Scope& scope, const Operation& op,
+                  const std::vector<Output>& grad_inputs,
+                  std::vector<Output>* grad_outputs) {
+  auto comparator = op.input(0);
+  auto x = op.input(1);
+  auto zeros = ZerosLike(scope, x);
+  auto grad = grad_inputs[0];
+
+  auto gx_1 = Where3(scope, comparator, grad, zeros);
+  auto gx_2 = Where3(scope, comparator, zeros, grad);
+
+  grad_outputs->push_back(NoGradient());
+  grad_outputs->push_back(gx_1);
+  grad_outputs->push_back(gx_2);
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Select", SelectGrad);
+
 Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index dde4b700601..79148843a88 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -882,5 +882,17 @@ TEST_F(NaryGradTest, Prod) {
   RunTest({x}, {x_shape}, {y}, {y_shape});
 }
 
+TEST_F(NaryGradTest, Select) {
+  TensorShape shape({3, 2});
+  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  // Use constant values to avoid instability when computing
+  Tensor c =
+      test::AsTensor<float>({-3.5f, 1.5f, -1.2f, 3.0f, -2.5f, 2.8f}, {3, 2});
+  auto zero = Cast(scope_, Const(scope_, 0.0), c.dtype());
+  auto y = Where3(scope_, Greater(scope_, c, zero), x1, x2);
+  RunTest({x1, x2}, {shape, shape}, {y}, {shape});
+}
+
 }  // namespace
 }  // namespace tensorflow