Add Sub gradients
This commit is contained in:
parent
604fa5673e
commit
e9204018c3
@ -63,6 +63,7 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -231,6 +232,41 @@ Status NegGradModel(AbstractContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] - inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status SubGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> sub_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(sub_outputs),
|
||||
"Sub")); // Compute x-y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto sub_output : sub_outputs) {
|
||||
sub_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
AbstractContext* BuildFunction(const char* fn_name) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -612,6 +648,67 @@ TEST_P(CppGradients, TestNegGrad) {
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSubGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// tape.watch(y)
|
||||
// y = x - y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
s = getValue(outputs[1], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, -1.0);
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSetAttrString) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
@ -221,6 +221,36 @@ class NegGradientFunction : public GradientFunction {
|
||||
~NegGradientFunction() override {}
|
||||
};
|
||||
|
||||
class SubGradientFunction : public GradientFunction {
|
||||
public:
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
/* Given upstream grad U and a Sub op A-B, the gradients are:
|
||||
*
|
||||
* dA = U
|
||||
* dB = -U
|
||||
*
|
||||
*/
|
||||
|
||||
grad_outputs->resize(2);
|
||||
|
||||
// Grad for A
|
||||
DCHECK(grad_inputs[0]);
|
||||
(*grad_outputs)[0] = grad_inputs[0];
|
||||
|
||||
// Grad for B
|
||||
// negate the upstream grad
|
||||
std::vector<AbstractTensorHandle*> neg_outputs(1);
|
||||
std::string name = "Neg_Sub_Grad_B";
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(neg_outputs), name.c_str()));
|
||||
(*grad_outputs)[1] = neg_outputs[0];
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
~SubGradientFunction() override {}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
@ -268,5 +298,14 @@ BackwardFunction* NegRegisterer(const ForwardOperation& op) {
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* SubRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new SubGradientFunction;
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -25,6 +25,7 @@ BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* NegRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SubRegisterer(const ForwardOperation& op);
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -64,5 +64,16 @@ PYBIND11_MODULE(_math_ops, m) {
|
||||
ops::Neg(ctx, {a}, absl::MakeSpan(outputs), name));
|
||||
return outputs[0];
|
||||
});
|
||||
m.def("sub", [](AbstractContext* ctx, AbstractTensorHandle* a,
|
||||
AbstractTensorHandle* b, const char* name) {
|
||||
int num_outputs = 1;
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
if (!name) {
|
||||
name = "Sub";
|
||||
}
|
||||
MaybeRaiseRegisteredFromStatus(
|
||||
ops::Sub(ctx, {a, b}, absl::MakeSpan(outputs), name));
|
||||
return outputs[0];
|
||||
});
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -35,3 +35,8 @@ def mat_mul(a, b, name=None):
|
||||
def neg(a, name=None):
|
||||
ctx = context.get_default()
|
||||
return _math_ops.neg(ctx, a, name)
|
||||
|
||||
|
||||
def sub(a, b, name=None):
|
||||
ctx = context.get_default()
|
||||
return _math_ops.sub(ctx, a, b, name)
|
||||
|
@ -37,6 +37,7 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
|
||||
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -207,6 +207,55 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
|
||||
eager_outputs = model(a)
|
||||
self.assertAllEqual(eager_outputs.numpy(), [-1.0])
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("Graph", False),
|
||||
("Mlir", True),
|
||||
])
|
||||
def testSub(self, use_mlir):
|
||||
if use_mlir:
|
||||
SetTracingImplementation("mlir")
|
||||
|
||||
def model(a, b):
|
||||
return unified_math_ops.sub(a, b)
|
||||
|
||||
with context_lib.set_default(get_immediate_execution_context()):
|
||||
a = TensorCastHelper(constant_op.constant([1., 2.]))
|
||||
b = TensorCastHelper(constant_op.constant([3., 4.]))
|
||||
|
||||
func_output = def_function.function(model)(a, b)
|
||||
self.assertAllEqual(func_output.numpy(), [-2., -2.])
|
||||
|
||||
eager_output = model(a, b)
|
||||
self.assertAllEqual(eager_output.numpy(), [-2., -2.])
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("Graph", False),
|
||||
("Mlir", True),
|
||||
])
|
||||
def testSubGrad(self, use_mlir):
|
||||
if use_mlir:
|
||||
SetTracingImplementation("mlir")
|
||||
|
||||
def model(a, b):
|
||||
with tape_lib.GradientTape() as tape:
|
||||
tape.watch(a)
|
||||
tape.watch(b)
|
||||
result = unified_math_ops.sub(a, b)
|
||||
grads = tape.gradient(result, [a, b])
|
||||
return grads
|
||||
|
||||
with context_lib.set_default(get_immediate_execution_context()):
|
||||
a = TensorCastHelper(constant_op.constant([1., 2.]))
|
||||
b = TensorCastHelper(constant_op.constant([3., 4.]))
|
||||
|
||||
func_outputs = def_function.function(model)(a, b)
|
||||
self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0])
|
||||
self.assertAllEqual(func_outputs[1].numpy(), [-1.0, -1.0])
|
||||
|
||||
eager_outputs = model(a, b)
|
||||
self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0])
|
||||
self.assertAllEqual(eager_outputs[1].numpy(), [-1.0, -1.0])
|
||||
|
||||
|
||||
class UnifiedTapeBenchmark(test.Benchmark):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user