Add Sub gradients

Võ Văn Nghĩa 2020-10-15 21:47:40 +07:00
parent 604fa5673e
commit e9204018c3
7 changed files with 203 additions and 0 deletions

View File

@@ -63,6 +63,7 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
return Status::OK();
}
@@ -231,6 +232,41 @@ Status NegGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// y = inputs[0] - inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status SubGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> sub_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
absl::MakeSpan(sub_outputs),
"Sub")); // Compute x-y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto sub_output : sub_outputs) {
sub_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
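Note on the expected values: with z = inputs[0] - inputs[1] and upstream gradient u, the analytic gradients are dz/d(inputs[0]) = u and dz/d(inputs[1]) = -u. Because no output_gradients are passed to ComputeGradient, the tape seeds u with ones by default, so the scalar inputs used in TestSubGrad below should produce 1.0 and -1.0.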
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@@ -612,6 +648,67 @@ TEST_P(CppGradients, TestNegGrad) {
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSubGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// z = x - y
// outputs = tape.gradient(z, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, -1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestSetAttrString) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

View File

@@ -221,6 +221,36 @@ class NegGradientFunction : public GradientFunction {
~NegGradientFunction() override {}
};
class SubGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a Sub op A-B, the gradients are:
*
* dA = U
* dB = -U
*
*/
grad_outputs->resize(2);
// Grad for A
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
// Grad for B
// Negate the upstream grad.
std::vector<AbstractTensorHandle*> neg_outputs(1);
std::string name = "Neg_Sub_Grad_B";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(neg_outputs), name.c_str()));
(*grad_outputs)[1] = neg_outputs[0];
return Status::OK();
}
~SubGradientFunction() override {}
};
} // namespace
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
@@ -268,5 +298,14 @@ BackwardFunction* NegRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* SubRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new SubGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow
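A quick sanity check of the formula above: with an incoming gradient U of [1., 1.], Compute returns [1., 1.] for A and [-1., -1.] for B, which is exactly what testSubGrad in unified_api_test.py asserts. Emitting the negation through ops::Neg on ctx->ctx, rather than computing it in place, keeps the gradient function usable both eagerly and while tracing a function, with "Neg_Sub_Grad_B" naming the emitted op.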

View File

@@ -25,6 +25,7 @@ BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
BackwardFunction* NegRegisterer(const ForwardOperation& op);
BackwardFunction* SubRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow

View File

@@ -64,5 +64,16 @@ PYBIND11_MODULE(_math_ops, m) {
ops::Neg(ctx, {a}, absl::MakeSpan(outputs), name));
return outputs[0];
});
m.def("sub", [](AbstractContext* ctx, AbstractTensorHandle* a,
AbstractTensorHandle* b, const char* name) {
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(1);
if (!name) {
name = "Sub";
}
MaybeRaiseRegisteredFromStatus(
ops::Sub(ctx, {a, b}, absl::MakeSpan(outputs), name));
return outputs[0];
});
}
} // namespace tensorflow

View File

@@ -35,3 +35,8 @@ def mat_mul(a, b, name=None):
def neg(a, name=None):
ctx = context.get_default()
return _math_ops.neg(ctx, a, name)
def sub(a, b, name=None):
ctx = context.get_default()
return _math_ops.sub(ctx, a, b, name)
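A minimal eager usage sketch of the new wrapper. The surrounding harness (context_lib, get_immediate_execution_context, TensorCastHelper, constant_op, and the unified_math_ops alias) is assumed from unified_api_test.py below, not from this module:
# Hedged sketch mirroring testSub below; the names above are assumptions
# borrowed from the test harness, not defined in math_ops.py itself.
with context_lib.set_default(get_immediate_execution_context()):
  a = TensorCastHelper(constant_op.constant([1., 2.]))
  b = TensorCastHelper(constant_op.constant([3., 4.]))
  c = unified_math_ops.sub(a, b)  # expected value: [-2., -2.]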

View File

@@ -37,6 +37,7 @@ Status RegisterGradients(GradientRegistry* registry) {
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
return Status::OK();
}

View File

@@ -207,6 +207,55 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
eager_outputs = model(a)
self.assertAllEqual(eager_outputs.numpy(), [-1.0])
@parameterized.named_parameters([
("Graph", False),
("Mlir", True),
])
def testSub(self, use_mlir):
if use_mlir:
SetTracingImplementation("mlir")
def model(a, b):
return unified_math_ops.sub(a, b)
with context_lib.set_default(get_immediate_execution_context()):
a = TensorCastHelper(constant_op.constant([1., 2.]))
b = TensorCastHelper(constant_op.constant([3., 4.]))
func_output = def_function.function(model)(a, b)
self.assertAllEqual(func_output.numpy(), [-2., -2.])
eager_output = model(a, b)
self.assertAllEqual(eager_output.numpy(), [-2., -2.])
@parameterized.named_parameters([
("Graph", False),
("Mlir", True),
])
def testSubGrad(self, use_mlir):
if use_mlir:
SetTracingImplementation("mlir")
def model(a, b):
with tape_lib.GradientTape() as tape:
tape.watch(a)
tape.watch(b)
result = unified_math_ops.sub(a, b)
grads = tape.gradient(result, [a, b])
return grads
with context_lib.set_default(get_immediate_execution_context()):
a = TensorCastHelper(constant_op.constant([1., 2.]))
b = TensorCastHelper(constant_op.constant([3., 4.]))
func_outputs = def_function.function(model)(a, b)
self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0])
self.assertAllEqual(func_outputs[1].numpy(), [-1.0, -1.0])
eager_outputs = model(a, b)
self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0])
self.assertAllEqual(eager_outputs[1].numpy(), [-1.0, -1.0])
class UnifiedTapeBenchmark(test.Benchmark):