Add unified_api_test for neg gradients

Võ Văn Nghĩa 2020-10-14 00:44:46 +07:00
parent 8897653779
commit 604fa5673e
3 changed files with 46 additions and 6 deletions


@@ -213,14 +213,9 @@ class NegGradientFunction : public GradientFunction {
      */
     grad_outputs->resize(1);
     // Grad for X
-    std::vector<AbstractTensorHandle*> neg_outputs(1);
     std::string name = "Neg_Grad";
     TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
-                                absl::MakeSpan(neg_outputs), name.c_str()));
-    (*grad_outputs)[0] = neg_outputs[0];
+                                absl::MakeSpan(*grad_outputs), name.c_str()));
     return Status::OK();
   }
   ~NegGradientFunction() override {}
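For context, the math this gradient function implements (dX = -U for Y = -X) can be sanity-checked with the stable public TensorFlow API. A minimal sketch, independent of the experimental unified API and the C++ internals changed above:

# Minimal sketch using the stable public API, not the experimental unified API
# this commit touches: the gradient of Neg is the negated upstream gradient.
import tensorflow as tf

x = tf.constant([2.0])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.negative(x)          # Y = -X
grad = tape.gradient(y, x)    # dX = -U; with U = ones, this is [-1.]
print(grad.numpy())           # [-1.]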


@@ -36,6 +36,7 @@ Status RegisterGradients(GradientRegistry* registry) {
   TF_RETURN_IF_ERROR(
       registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                          SparseSoftmaxCrossEntropyWithLogitsRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
   return Status::OK();
 }
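The new Register call wires NegRegisterer into the C++ GradientRegistry so the tape can look up the Neg gradient by op name. As a rough illustration only, the classic Python graph-gradient registry follows the same name-keyed pattern; the op type "HypotheticalNeg" below is made up for the sketch:

# Illustration only: the classic Python-level gradient registry is keyed by op
# type name, much like the C++ GradientRegistry above. "HypotheticalNeg" is a
# made-up op type for this sketch.
import tensorflow as tf

@tf.RegisterGradient("HypotheticalNeg")
def _hypothetical_neg_grad(op, upstream_grad):
  # Same math as NegGradientFunction: dX = -U.
  return -upstream_grad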


@@ -163,6 +163,50 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
       eager_output = model(negative)
       self.assertAllEqual(eager_output.numpy(), [0.])
 
+  @parameterized.named_parameters([
+      ("Graph", False),
+      ("Mlir", True),
+  ])
+  def testNeg(self, use_mlir):
+    if use_mlir:
+      SetTracingImplementation("mlir")
+
+    def model(a):
+      return unified_math_ops.neg(a)
+
+    with context_lib.set_default(get_immediate_execution_context()):
+      a = TensorCastHelper(constant_op.constant([2.]))
+
+      func_output = def_function.function(model)(a)
+      self.assertAllEqual(func_output.numpy(), [-2.])
+
+      eager_output = model(a)
+      self.assertAllEqual(eager_output.numpy(), [-2.])
+
+  @parameterized.named_parameters([
+      ("Graph", False),
+      ("Mlir", True),
+  ])
+  def testNegGrad(self, use_mlir):
+    if use_mlir:
+      SetTracingImplementation("mlir")
+
+    def model(a):
+      with tape_lib.GradientTape() as tape:
+        tape.watch(a)
+        result = unified_math_ops.neg(a)
+      grads = tape.gradient(result, a)
+      return grads
+
+    with context_lib.set_default(get_immediate_execution_context()):
+      a = TensorCastHelper(constant_op.constant([2.]))
+
+      func_outputs = def_function.function(model)(a)
+      self.assertAllEqual(func_outputs.numpy(), [-1.0])
+
+      eager_outputs = model(a)
+      self.assertAllEqual(eager_outputs.numpy(), [-1.0])
+
 
 class UnifiedTapeBenchmark(test.Benchmark):
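The two new tests run the same model both as a traced function and eagerly, under the graph and MLIR tracing implementations. A sketch of the equivalence they assert, written against the stable public API rather than the experimental unified_math_ops/def_function shims used above:

# Sketch of the property the new tests check, using the stable public API
# instead of the experimental unified API: tracing the model with tf.function
# and running it eagerly produce the same Neg gradient.
import tensorflow as tf

def model(a):
  with tf.GradientTape() as tape:
    tape.watch(a)
    result = tf.negative(a)
  return tape.gradient(result, a)

a = tf.constant([2.0])
func_grads = tf.function(model)(a)
eager_grads = model(a)
assert func_grads.numpy()[0] == -1.0
assert eager_grads.numpy()[0] == -1.0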