Add unified_api_test for neg gradients
This commit is contained in:
parent
8897653779
commit
604fa5673e
@ -213,14 +213,9 @@ class NegGradientFunction : public GradientFunction {
|
||||
*/
|
||||
|
||||
grad_outputs->resize(1);
|
||||
|
||||
// Grad for X
|
||||
std::vector<AbstractTensorHandle*> neg_outputs(1);
|
||||
std::string name = "Neg_Grad";
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(neg_outputs), name.c_str()));
|
||||
|
||||
(*grad_outputs)[0] = neg_outputs[0];
|
||||
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~NegGradientFunction() override {}
|
||||
|
@ -36,6 +36,7 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
|
||||
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -163,6 +163,50 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
|
||||
eager_output = model(negative)
|
||||
self.assertAllEqual(eager_output.numpy(), [0.])
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("Graph", False),
|
||||
("Mlir", True),
|
||||
])
|
||||
def testNeg(self, use_mlir):
|
||||
if use_mlir:
|
||||
SetTracingImplementation("mlir")
|
||||
|
||||
def model(a):
|
||||
return unified_math_ops.neg(a)
|
||||
|
||||
with context_lib.set_default(get_immediate_execution_context()):
|
||||
a = TensorCastHelper(constant_op.constant([2.]))
|
||||
|
||||
func_output = def_function.function(model)(a)
|
||||
self.assertAllEqual(func_output.numpy(), [-2.])
|
||||
|
||||
eager_output = model(a)
|
||||
self.assertAllEqual(eager_output.numpy(), [-2.])
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("Graph", False),
|
||||
("Mlir", True),
|
||||
])
|
||||
def testNegGrad(self, use_mlir):
|
||||
if use_mlir:
|
||||
SetTracingImplementation("mlir")
|
||||
|
||||
def model(a):
|
||||
with tape_lib.GradientTape() as tape:
|
||||
tape.watch(a)
|
||||
result = unified_math_ops.neg(a)
|
||||
grads = tape.gradient(result, a)
|
||||
return grads
|
||||
|
||||
with context_lib.set_default(get_immediate_execution_context()):
|
||||
a = TensorCastHelper(constant_op.constant([2.]))
|
||||
|
||||
func_outputs = def_function.function(model)(a)
|
||||
self.assertAllEqual(func_outputs.numpy(), [-1.0])
|
||||
|
||||
eager_outputs = model(a)
|
||||
self.assertAllEqual(eager_outputs.numpy(), [-1.0])
|
||||
|
||||
|
||||
class UnifiedTapeBenchmark(test.Benchmark):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user