Add Sub gradients

parent 604fa5673e
commit e9204018c3
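This change adds gradient support for the Sub op end to end: a SubGradientFunction and SubRegisterer in the C++ gradient library, registration in the gradient registries, a SubGradModel tape model with a CppGradients test, a pybind11 `sub` binding, a Python `sub` wrapper, and unified-API Python tests. For reference, the math being implemented (a restatement of the comment in SubGradientFunction below, not new behavior): for y = a - b with upstream gradient u,

$$\frac{\partial L}{\partial a} = u \cdot \frac{\partial (a-b)}{\partial a} = u, \qquad \frac{\partial L}{\partial b} = u \cdot \frac{\partial (a-b)}{\partial b} = -u,$$

which is why the scalar test below expects gradients 1.0 and -1.0, and the vector test expects [1.0, 1.0] and [-1.0, -1.0].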
@@ -63,6 +63,7 @@ Status RegisterGradients(GradientRegistry* registry) {
   TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
   return Status::OK();
 }
 
@@ -231,6 +232,41 @@ Status NegGradModel(AbstractContext* ctx,
   return Status::OK();
 }
 
+// Computes
+// y = inputs[0] - inputs[1]
+// return grad(y, {inputs[0], inputs[1]})
+Status SubGradModel(AbstractContext* ctx,
+                    absl::Span<AbstractTensorHandle* const> inputs,
+                    absl::Span<AbstractTensorHandle*> outputs,
+                    const GradientRegistry& registry) {
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(inputs[0]));  // Watch x.
+  tape->Watch(ToId(inputs[1]));  // Watch y.
+  std::vector<AbstractTensorHandle*> sub_outputs(1);
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
+                              absl::MakeSpan(sub_outputs),
+                              "Sub"));  // Compute x-y.
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+
+  std::vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(tape->ComputeGradient(
+      vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
+      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
+      source_tensors_that_are_targets,
+      /*output_gradients=*/{}, &out_grads,
+      /*build_default_zeros_grads=*/false));
+  for (auto sub_output : sub_outputs) {
+    sub_output->Unref();
+  }
+  outputs[0] = out_grads[0];
+  outputs[1] = out_grads[1];
+  delete tape;
+  return Status::OK();
+}
+
 AbstractContext* BuildFunction(const char* fn_name) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
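For orientation, the tape logic in SubGradModel above corresponds to this small Python analogue using the public tf.GradientTape API (an illustrative sketch, not code from this change; the scalar value 2.0 matches TestSubGrad below):

import tensorflow as tf

x = tf.constant(2.0)
y = tf.constant(2.0)
with tf.GradientTape(persistent=False) as tape:
  tape.watch(x)  # Watch x.
  tape.watch(y)  # Watch y.
  z = x - y      # Compute x - y (the "Sub" op).
dz_dx, dz_dy = tape.gradient(z, [x, y])
# dz_dx.numpy() == 1.0 and dz_dy.numpy() == -1.0, matching the EXPECT_EQ
# checks in TestSubGrad below.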
@@ -612,6 +648,67 @@ TEST_P(CppGradients, TestNegGrad) {
   result_tensor = nullptr;
 }
 
+TEST_P(CppGradients, TestSubGrad) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  AbstractTensorHandlePtr x;
+  {
+    AbstractTensorHandle* x_raw = nullptr;
+    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    x.reset(x_raw);
+  }
+
+  AbstractTensorHandlePtr y;
+  {
+    AbstractTensorHandle* y_raw = nullptr;
+    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    y.reset(y_raw);
+  }
+
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Pseudo-code:
+  //
+  // tape.watch(x)
+  // tape.watch(y)
+  // y = x - y
+  // outputs = tape.gradient(y, [x, y])
+  std::vector<AbstractTensorHandle*> outputs(2);
+  s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
+               absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* result_tensor;
+  s = getValue(outputs[0], &result_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
+  EXPECT_EQ(*result_value, 1.0);
+  outputs[0]->Unref();
+  TF_DeleteTensor(result_tensor);
+  result_tensor = nullptr;
+
+  s = getValue(outputs[1], &result_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  result_value = static_cast<float*>(TF_TensorData(result_tensor));
+  EXPECT_EQ(*result_value, -1.0);
+  outputs[1]->Unref();
+  TF_DeleteTensor(result_tensor);
+}
+
 TEST_P(CppGradients, TestSetAttrString) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
@@ -221,6 +221,36 @@ class NegGradientFunction : public GradientFunction {
   ~NegGradientFunction() override {}
 };
 
+class SubGradientFunction : public GradientFunction {
+ public:
+  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
+                 vector<AbstractTensorHandle*>* grad_outputs) override {
+    /* Given upstream grad U and a Sub op A-B, the gradients are:
+     *
+     *    dA = U
+     *    dB = -U
+     *
+     */
+
+    grad_outputs->resize(2);
+
+    // Grad for A
+    DCHECK(grad_inputs[0]);
+    (*grad_outputs)[0] = grad_inputs[0];
+
+    // Grad for B
+    // negate the upstream grad
+    std::vector<AbstractTensorHandle*> neg_outputs(1);
+    std::string name = "Neg_Sub_Grad_B";
+    TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
+                                absl::MakeSpan(neg_outputs), name.c_str()));
+    (*grad_outputs)[1] = neg_outputs[0];
+
+    return Status::OK();
+  }
+  ~SubGradientFunction() override {}
+};
+
 }  // namespace
 
 BackwardFunction* AddRegisterer(const ForwardOperation& op) {
@@ -268,5 +298,14 @@ BackwardFunction* NegRegisterer(const ForwardOperation& op) {
   return new BackwardFunction(gradient_function, default_gradients);
 }
 
+BackwardFunction* SubRegisterer(const ForwardOperation& op) {
+  // For ops with a single output, the gradient function is not called if there
+  // is no incoming gradient. So we do not need to worry about creating zeros
+  // grads in this case.
+  auto gradient_function = new SubGradientFunction;
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
+}
+
 }  // namespace gradients
 }  // namespace tensorflow
@@ -25,6 +25,7 @@ BackwardFunction* ExpRegisterer(const ForwardOperation& op);
 BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
 BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
 BackwardFunction* NegRegisterer(const ForwardOperation& op);
+BackwardFunction* SubRegisterer(const ForwardOperation& op);
 
 }  // namespace gradients
 }  // namespace tensorflow
@@ -64,5 +64,16 @@ PYBIND11_MODULE(_math_ops, m) {
         ops::Neg(ctx, {a}, absl::MakeSpan(outputs), name));
     return outputs[0];
   });
+  m.def("sub", [](AbstractContext* ctx, AbstractTensorHandle* a,
+                  AbstractTensorHandle* b, const char* name) {
+    int num_outputs = 1;
+    std::vector<AbstractTensorHandle*> outputs(1);
+    if (!name) {
+      name = "Sub";
+    }
+    MaybeRaiseRegisteredFromStatus(
+        ops::Sub(ctx, {a, b}, absl::MakeSpan(outputs), name));
+    return outputs[0];
+  });
 }
 }  // namespace tensorflow
@@ -35,3 +35,8 @@ def mat_mul(a, b, name=None):
 def neg(a, name=None):
   ctx = context.get_default()
   return _math_ops.neg(ctx, a, name)
+
+
+def sub(a, b, name=None):
+  ctx = context.get_default()
+  return _math_ops.sub(ctx, a, b, name)
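A hedged sketch of how the new Python wrapper is meant to be called, mirroring testSub later in this diff. The import path of the experimental module is an assumption, and the snippet assumes a default context has already been installed via context_lib.set_default(...), exactly as the unified-API tests below do:

# Assumed import path for the experimental wrapper defined above.
from tensorflow.python.framework.experimental import math_ops as unified_math_ops

def model(a, b):
  # name is optional; the pybind layer above defaults it to "Sub".
  return unified_math_ops.sub(a, b)

# With a = [1., 2.] and b = [3., 4.] (the values used in testSub),
# model(a, b) evaluates to [-2., -2.].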
@@ -37,6 +37,7 @@ Status RegisterGradients(GradientRegistry* registry) {
   registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                      SparseSoftmaxCrossEntropyWithLogitsRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
   return Status::OK();
 }
 
@@ -207,6 +207,55 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
     eager_outputs = model(a)
     self.assertAllEqual(eager_outputs.numpy(), [-1.0])
 
+  @parameterized.named_parameters([
+      ("Graph", False),
+      ("Mlir", True),
+  ])
+  def testSub(self, use_mlir):
+    if use_mlir:
+      SetTracingImplementation("mlir")
+
+    def model(a, b):
+      return unified_math_ops.sub(a, b)
+
+    with context_lib.set_default(get_immediate_execution_context()):
+      a = TensorCastHelper(constant_op.constant([1., 2.]))
+      b = TensorCastHelper(constant_op.constant([3., 4.]))
+
+      func_output = def_function.function(model)(a, b)
+      self.assertAllEqual(func_output.numpy(), [-2., -2.])
+
+      eager_output = model(a, b)
+      self.assertAllEqual(eager_output.numpy(), [-2., -2.])
+
+  @parameterized.named_parameters([
+      ("Graph", False),
+      ("Mlir", True),
+  ])
+  def testSubGrad(self, use_mlir):
+    if use_mlir:
+      SetTracingImplementation("mlir")
+
+    def model(a, b):
+      with tape_lib.GradientTape() as tape:
+        tape.watch(a)
+        tape.watch(b)
+        result = unified_math_ops.sub(a, b)
+      grads = tape.gradient(result, [a, b])
+      return grads
+
+    with context_lib.set_default(get_immediate_execution_context()):
+      a = TensorCastHelper(constant_op.constant([1., 2.]))
+      b = TensorCastHelper(constant_op.constant([3., 4.]))
+
+      func_outputs = def_function.function(model)(a, b)
+      self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0])
+      self.assertAllEqual(func_outputs[1].numpy(), [-1.0, -1.0])
+
+      eager_outputs = model(a, b)
+      self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0])
+      self.assertAllEqual(eager_outputs[1].numpy(), [-1.0, -1.0])
+
 
 class UnifiedTapeBenchmark(test.Benchmark):
 