From e9204018c310f1ab53323f468c76a40b43ce92e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= <vovannghia2409@gmail.com>
Date: Thu, 15 Oct 2020 21:47:40 +0700
Subject: [PATCH] Add Sub gradients

---
 tensorflow/c/eager/gradients_test.cc          | 97 +++++++++++++++++++
 .../c/experimental/gradients/math_grad.cc     | 43 +++++++++
 .../c/experimental/gradients/math_grad.h      |  1 +
 .../python/framework/experimental/math_ops.cc | 11 +++
 .../python/framework/experimental/math_ops.py |  5 +
 .../python/framework/experimental/tape.cc     |  1 +
 .../experimental/unified_api_test.py          | 49 ++++++++++
 7 files changed, 207 insertions(+)

diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc
index fa775cfa0e4..7fafd6eaa07 100644
--- a/tensorflow/c/eager/gradients_test.cc
+++ b/tensorflow/c/eager/gradients_test.cc
@@ -63,6 +63,7 @@ Status RegisterGradients(GradientRegistry* registry) {
   TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
   return Status::OK();
 }
 
@@ -231,6 +232,41 @@ Status NegGradModel(AbstractContext* ctx,
   return Status::OK();
 }
 
+// Computes
+// y = inputs[0] - inputs[1]
+// return grad(y, {inputs[0], inputs[1]})
+Status SubGradModel(AbstractContext* ctx,
+                    absl::Span<AbstractTensorHandle* const> inputs,
+                    absl::Span<AbstractTensorHandle*> outputs,
+                    const GradientRegistry& registry) {
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(inputs[0]));  // Watch x.
+  tape->Watch(ToId(inputs[1]));  // Watch y.
+  std::vector<AbstractTensorHandle*> sub_outputs(1);
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
+                              absl::MakeSpan(sub_outputs),
+                              "Sub"));  // Compute x-y.
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+
+  std::vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(tape->ComputeGradient(
+      vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
+      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
+      source_tensors_that_are_targets,
+      /*output_gradients=*/{}, &out_grads,
+      /*build_default_zeros_grads=*/false));
+  for (auto sub_output : sub_outputs) {
+    sub_output->Unref();
+  }
+  outputs[0] = out_grads[0];
+  outputs[1] = out_grads[1];
+  delete tape;
+  return Status::OK();
+}
+
 AbstractContext* BuildFunction(const char* fn_name) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
@@ -612,6 +648,67 @@ TEST_P(CppGradients, TestNegGrad) {
   result_tensor = nullptr;
 }
 
+TEST_P(CppGradients, TestSubGrad) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  AbstractTensorHandlePtr x;
+  {
+    AbstractTensorHandle* x_raw = nullptr;
+    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    x.reset(x_raw);
+  }
+
+  AbstractTensorHandlePtr y;
+  {
+    AbstractTensorHandle* y_raw = nullptr;
+    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    y.reset(y_raw);
+  }
+
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Pseudo-code:
+  //
+  // tape.watch(x)
+  // tape.watch(y)
+  // z = x - y
+  // outputs = tape.gradient(z, [x, y])
+  std::vector<AbstractTensorHandle*> outputs(2);
+  s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
+               absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* result_tensor;
+  s = getValue(outputs[0], &result_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
+  EXPECT_EQ(*result_value, 1.0);
+  outputs[0]->Unref();
+  TF_DeleteTensor(result_tensor);
+  result_tensor = nullptr;
+
+  s = getValue(outputs[1], &result_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  result_value = static_cast<float*>(TF_TensorData(result_tensor));
+  EXPECT_EQ(*result_value, -1.0);
+  outputs[1]->Unref();
+  TF_DeleteTensor(result_tensor);
+}
+
 TEST_P(CppGradients, TestSetAttrString) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc
index 13517e6ad90..b9d8bd15c1a 100644
--- a/tensorflow/c/experimental/gradients/math_grad.cc
+++ b/tensorflow/c/experimental/gradients/math_grad.cc
@@ -221,6 +221,40 @@ class NegGradientFunction : public GradientFunction {
   ~NegGradientFunction() override {}
 };
 
+class SubGradientFunction : public GradientFunction {
+ public:
+  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
+                 vector<AbstractTensorHandle*>* grad_outputs) override {
+    /* Given upstream grad U and a Sub op A-B, the gradients are:
+     *
+     *    dA =  U
+     *    dB = -U
+     *
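+     * e.g. U = [1, 1] gives dA = [1, 1] and dB = [-1, -1].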
+     */
+
+    grad_outputs->resize(2);
+
+    // Grad for A: pass through the upstream grad. Hold an extra reference
+    // since the caller takes ownership of the returned handles.
+    DCHECK(grad_inputs[0]);
+    (*grad_outputs)[0] = grad_inputs[0];
+    (*grad_outputs)[0]->Ref();
+
+    // Grad for B
+    // Negate the upstream grad.
+    std::vector<AbstractTensorHandle*> neg_outputs(1);
+    std::string name = "Neg_Sub_Grad_B";
+    TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
+                                absl::MakeSpan(neg_outputs), name.c_str()));
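+    // neg_outputs[0] is a fresh handle; no extra Ref is needed here.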
+    (*grad_outputs)[1] = neg_outputs[0];
+
+    return Status::OK();
+  }
+  ~SubGradientFunction() override {}
+};
+
 }  // namespace
 
 BackwardFunction* AddRegisterer(const ForwardOperation& op) {
@@ -268,5 +302,14 @@ BackwardFunction* NegRegisterer(const ForwardOperation& op) {
   return new BackwardFunction(gradient_function, default_gradients);
 }
 
+BackwardFunction* SubRegisterer(const ForwardOperation& op) {
+  // For ops with a single output, the gradient function is not called if there
+  // is no incoming gradient. So we do not need to worry about creating zeros
+  // grads in this case.
+  auto gradient_function = new SubGradientFunction;
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
+}
+
 }  // namespace gradients
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h
index 38d83b959d2..756c5f84153 100644
--- a/tensorflow/c/experimental/gradients/math_grad.h
+++ b/tensorflow/c/experimental/gradients/math_grad.h
@@ -25,6 +25,7 @@ BackwardFunction* ExpRegisterer(const ForwardOperation& op);
 BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
 BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
 BackwardFunction* NegRegisterer(const ForwardOperation& op);
+BackwardFunction* SubRegisterer(const ForwardOperation& op);
 
 }  // namespace gradients
 }  // namespace tensorflow
diff --git a/tensorflow/python/framework/experimental/math_ops.cc b/tensorflow/python/framework/experimental/math_ops.cc
index 7f2f4f28318..5e9522f1999 100644
--- a/tensorflow/python/framework/experimental/math_ops.cc
+++ b/tensorflow/python/framework/experimental/math_ops.cc
@@ -64,5 +64,16 @@ PYBIND11_MODULE(_math_ops, m) {
               ops::Neg(ctx, {a}, absl::MakeSpan(outputs), name));
           return outputs[0];
         });
+  m.def("sub", [](AbstractContext* ctx, AbstractTensorHandle* a,
+                  AbstractTensorHandle* b, const char* name) {
+    int num_outputs = 1;
+    std::vector<AbstractTensorHandle*> outputs(num_outputs);
+    if (!name) {
+      name = "Sub";
+    }
+    MaybeRaiseRegisteredFromStatus(
+        ops::Sub(ctx, {a, b}, absl::MakeSpan(outputs), name));
+    return outputs[0];
+  });
 }
 }  // namespace tensorflow
diff --git a/tensorflow/python/framework/experimental/math_ops.py b/tensorflow/python/framework/experimental/math_ops.py
index d5656e99319..879cddfa036 100644
--- a/tensorflow/python/framework/experimental/math_ops.py
+++ b/tensorflow/python/framework/experimental/math_ops.py
@@ -35,3 +35,8 @@ def mat_mul(a, b, name=None):
 def neg(a, name=None):
   ctx = context.get_default()
   return _math_ops.neg(ctx, a, name)
+
+
+def sub(a, b, name=None):
+  ctx = context.get_default()
+  return _math_ops.sub(ctx, a, b, name)
diff --git a/tensorflow/python/framework/experimental/tape.cc b/tensorflow/python/framework/experimental/tape.cc
index 5adb028aa55..a6975c085ac 100644
--- a/tensorflow/python/framework/experimental/tape.cc
+++ b/tensorflow/python/framework/experimental/tape.cc
@@ -37,6 +37,7 @@ Status RegisterGradients(GradientRegistry* registry) {
       registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                          SparseSoftmaxCrossEntropyWithLogitsRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
   return Status::OK();
 }
 
diff --git a/tensorflow/python/framework/experimental/unified_api_test.py b/tensorflow/python/framework/experimental/unified_api_test.py
index 9d02d51cf9a..8edb3f51f7a 100644
--- a/tensorflow/python/framework/experimental/unified_api_test.py
+++ b/tensorflow/python/framework/experimental/unified_api_test.py
@@ -207,6 +207,55 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
       eager_outputs = model(a)
       self.assertAllEqual(eager_outputs.numpy(), [-1.0])
 
+  @parameterized.named_parameters([
+      ("Graph", False),
+      ("Mlir", True),
+  ])
+  def testSub(self, use_mlir):
+    if use_mlir:
+      SetTracingImplementation("mlir")
+
+    def model(a, b):
+      return unified_math_ops.sub(a, b)
+
+    with context_lib.set_default(get_immediate_execution_context()):
+      a = TensorCastHelper(constant_op.constant([1., 2.]))
+      b = TensorCastHelper(constant_op.constant([3., 4.]))
+
+      func_output = def_function.function(model)(a, b)
+      self.assertAllEqual(func_output.numpy(), [-2., -2.])
+
+      eager_output = model(a, b)
+      self.assertAllEqual(eager_output.numpy(), [-2., -2.])
+
+  @parameterized.named_parameters([
+      ("Graph", False),
+      ("Mlir", True),
+  ])
+  def testSubGrad(self, use_mlir):
+    if use_mlir:
+      SetTracingImplementation("mlir")
+
+    def model(a, b):
+      with tape_lib.GradientTape() as tape:
+        tape.watch(a)
+        tape.watch(b)
+        result = unified_math_ops.sub(a, b)
+      grads = tape.gradient(result, [a, b])
+      return grads
+
+    with context_lib.set_default(get_immediate_execution_context()):
+      a = TensorCastHelper(constant_op.constant([1., 2.]))
+      b = TensorCastHelper(constant_op.constant([3., 4.]))
+
+      func_outputs = def_function.function(model)(a, b)
+      self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0])
+      self.assertAllEqual(func_outputs[1].numpy(), [-1.0, -1.0])
+
+      eager_outputs = model(a, b)
+      self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0])
+      self.assertAllEqual(eager_outputs[1].numpy(), [-1.0, -1.0])
+
 
 class UnifiedTapeBenchmark(test.Benchmark):