Merge pull request #45798 from vnvo2409:gradients

PiperOrigin-RevId: 350605627
Change-Id: I81610a0971dc1962496f1e00e41399cf90173f8c
TensorFlower Gardener 2021-01-07 11:42:31 -08:00
commit 3c4b13bcdb
6 changed files with 98 additions and 86 deletions


@@ -342,59 +342,6 @@ TEST_P(CppGradients, TestMatMulTranspose) {
  }
}

TEST_P(CppGradients, TestReluGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
  int64_t X_dims[] = {3, 3};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(X)
   * Y = Relu(X)
   * outputs = tape.gradient(Y, [X])
   */
  std::vector<AbstractTensorHandle*> outputs(1);
  s = RunModel(ReluGradModel, ctx.get(), {X.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dX_tensor;
  s = GetValue(outputs[0], &dX_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[9] = {0};
  memcpy(&result_data[0], TF_TensorData(dX_tensor),
         TF_TensorByteSize(dX_tensor));
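  // Relu passes the incoming gradient through where X > 0 and blocks it
  // elsewhere, so dX is expected to be 1 for the positive entries of X_vals
  // and 0 for the non-positive ones.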
  float expected_dX[9] = {1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f,
                          1.0f, 0.0f, 0.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 9; j++) {
    ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
  }

  outputs[0]->Unref();
  TF_DeleteTensor(dX_tensor);
}

TEST_P(CppGradients, TestMNISTGrad) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {


@@ -159,30 +159,6 @@ Status MatMulTransposeModel(AbstractContext* ctx,
  return Status::OK();
}

Status ReluGradModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs,
                     const GradientRegistry& registry) {
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(inputs[0]);  // Watch X
  vector<AbstractTensorHandle*> relu_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
                               absl::MakeSpan(relu_outputs),
                               "relu0"));  // Relu(X)

  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/relu_outputs,
                                           /*sources=*/inputs,
                                           /*output_gradients=*/{}, outputs));
  for (auto relu_output : relu_outputs) {
    relu_output->Unref();
  }

  delete tape;
  return Status::OK();
}

Status MNISTGradModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,

@@ -61,12 +61,6 @@ Status MatMulTransposeModel(AbstractContext* ctx,
                            absl::Span<AbstractTensorHandle*> outputs,
                            const GradientRegistry& registry);

// Test Model to verify ReluGrad functionality
Status ReluGradModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs,
                     const GradientRegistry& registry);

// Test Model to verify Multi-grad functionality for MNIST
Status MNISTGradModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,

@@ -74,6 +74,37 @@ void CompareNumericalAndAutodiffGradients(
  }
}
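
// Reads back the tensor behind `t`, checks its shape against `dims`, and
// compares each element against the manually computed value in `manuals`;
// an `abs_error` of 0 requests exact equality instead of ASSERT_NEAR.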
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
                      absl::Span<const int64_t> dims, double abs_error) {
  TF_Tensor* analytical_tensor;
  auto s = GetValue(t, &analytical_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  int64_t num_elem_analytical = 1;
  auto num_dims_analytical = TF_NumDims(analytical_tensor);
  ASSERT_EQ(dims.size(), num_dims_analytical);
  for (int j = 0; j < num_dims_analytical; j++) {
    auto dim_analytical = TF_Dim(analytical_tensor, j);
    ASSERT_EQ(dims[j], dim_analytical);
    num_elem_analytical *= dim_analytical;
  }

  float* danalytical = new float[num_elem_analytical]{0};
  memcpy(&danalytical[0], TF_TensorData(analytical_tensor),
         TF_TensorByteSize(analytical_tensor));

  for (int64_t j = 0; j < num_elem_analytical; j++) {
    if (abs_error == 0) {
      ASSERT_EQ(manuals[j], danalytical[j]);
    } else {
      ASSERT_NEAR(manuals[j], danalytical[j], abs_error);
    }
  }

  TF_DeleteTensor(analytical_tensor);
  delete[] danalytical;
}

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow


@@ -26,6 +26,9 @@ void CompareNumericalAndAutodiffGradients(
    absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
    double abs_error = 1e-2);
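
// Checks the values of `t` element-wise against the manually computed values
// in `manuals`, within an absolute tolerance of `abs_error`.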
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
                      absl::Span<const int64_t> dims, double abs_error = 1e-2);

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow


@@ -29,6 +29,34 @@ namespace {
using tensorflow::TF_StatusPtr;

Status ReluModel(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs) {
  return ops::Relu(ctx, inputs, outputs, "Relu");
}
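
// Runs Relu through a TapeContext so the forward op is recorded, then asks
// the tape for the gradient of the Relu output with respect to the watched
// input.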
Status ReluGradModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Relu", ReluRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
                               absl::MakeSpan(temp_outputs), "ReluGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status SparseSoftmaxCrossEntropyWithLogitsModel(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs) {

@@ -125,6 +153,40 @@ class CppGradients
  bool UseFunction() const { return std::get<2>(GetParam()); }
};

TEST_P(CppGradients, TestReluGrad) {
  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 10.0f, -1.0f};
  int64_t X_dims[] = {3, 3};
  AbstractTensorHandlePtr X;
  {
    AbstractTensorHandle* X_raw;
    Status s =
        TestTensorHandleWithDimsFloat(ctx_.get(), X_vals, X_dims, 2, &X_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    X.reset(X_raw);
  }

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      ReluModel, ReluGradModel, ctx_.get(), {X.get()}, UseFunction()));

  // Mathematically, Relu isn't differentiable at `0`. So `gradient_checker`
  // does not work with it.
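  // The gradient at exactly 0 is therefore checked manually below: running
  // ReluGradModel on a scalar 0 input is expected to produce a gradient of
  // 0.0f, compared with abs_error = 0 (exact equality).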
  AbstractTensorHandlePtr Y;
  {
    AbstractTensorHandle* Y_raw;
    Status s = TestScalarTensorHandle(ctx_.get(), 0.0f, &Y_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    Y.reset(Y_raw);
  }

  std::vector<AbstractTensorHandle*> outputs(1);
  auto s = RunModel(ReluGradModel, ctx_.get(), {Y.get()},
                    absl::MakeSpan(outputs), UseFunction());
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {},
                                           /*abs_error*/ 0));
  outputs[0]->Unref();
}

TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
  if (UseFunction()) {
    // TODO(b/168850692): Enable this.

@@ -158,8 +220,7 @@ TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      SparseSoftmaxCrossEntropyWithLogitsModel,
      SparseSoftmaxCrossEntropyWithLogitsGradModel, ctx_.get(),
      {X.get(), Y.get()},
      /*use_function=*/UseFunction()));
      {X.get(), Y.get()}, UseFunction()));
}

TEST_P(CppGradients, TestBiasAddGrad) {

@@ -192,7 +253,7 @@ TEST_P(CppGradients, TestBiasAddGrad) {
  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()},
      /*use_function=*/UseFunction()));
      UseFunction()));
}
#ifdef PLATFORM_GOOGLE