move TestReluGrad to nn_grad_test

parent 462a06442a
commit 4f19fb1826

@@ -342,59 +342,6 @@ TEST_P(CppGradients, TestMatMulTranspose) {
   }
 }
 
-TEST_P(CppGradients, TestReluGrad) {
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-
-  AbstractContextPtr ctx;
-  {
-    AbstractContext* ctx_raw = nullptr;
-    Status s =
-        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    ctx.reset(ctx_raw);
-  }
-
-  // X = data
-  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
-  int64_t X_dims[] = {3, 3};
-  int num_dims = 2;
-  AbstractTensorHandlePtr X =
-      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
-
-  GradientRegistry registry;
-  Status s = RegisterGradients(&registry);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  /* Pseudo-code:
-   *
-   * tape.watch(X)
-   * Y = Relu(X)
-   * outputs = tape.gradient(Y, [X])
-   */
-  std::vector<AbstractTensorHandle*> outputs(1);
-  s = RunModel(ReluGradModel, ctx.get(), {X.get()}, absl::MakeSpan(outputs),
-               /*use_function=*/!std::get<2>(GetParam()), registry);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  TF_Tensor* dX_tensor;
-  s = GetValue(outputs[0], &dX_tensor);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  float result_data[9] = {0};
-  memcpy(&result_data[0], TF_TensorData(dX_tensor),
-         TF_TensorByteSize(dX_tensor));
-
-  float expected_dX[9] = {1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
-  float tolerance = 1e-3;
-  for (int j = 0; j < 9; j++) {
-    ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
-  }
-
-  outputs[0]->Unref();
-  TF_DeleteTensor(dX_tensor);
-}
-
 TEST_P(CppGradients, TestMNISTGrad) {
   bool use_function = !std::get<2>(GetParam());
   if (use_function) {

@@ -159,30 +159,6 @@ Status MatMulTransposeModel(AbstractContext* ctx,
   return Status::OK();
 }
-
-Status ReluGradModel(AbstractContext* ctx,
-                     absl::Span<AbstractTensorHandle* const> inputs,
-                     absl::Span<AbstractTensorHandle*> outputs,
-                     const GradientRegistry& registry) {
-  auto tape = new Tape(/*persistent=*/false);
-  tape->Watch(inputs[0]);  // Watch X
-  vector<AbstractTensorHandle*> relu_outputs(1);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
-  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
-                               absl::MakeSpan(relu_outputs),
-                               "relu0"));  // Relu(X)
-
-  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/relu_outputs,
-                                           /*sources=*/inputs,
-                                           /*output_gradients=*/{}, outputs));
-
-  for (auto relu_output : relu_outputs) {
-    relu_output->Unref();
-  }
-
-  delete tape;
-  return Status::OK();
-}
 
 Status MNISTGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,

@@ -61,12 +61,6 @@ Status MatMulTransposeModel(AbstractContext* ctx,
                             absl::Span<AbstractTensorHandle*> outputs,
                             const GradientRegistry& registry);
 
-// Test Model to verify ReluGrad functionality
-Status ReluGradModel(AbstractContext* ctx,
-                     absl::Span<AbstractTensorHandle* const> inputs,
-                     absl::Span<AbstractTensorHandle*> outputs,
-                     const GradientRegistry& registry);
-
 // Test Model to verify Multi-grad functionality for MNIST
 Status MNISTGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,

@@ -74,6 +74,39 @@ void CompareNumericalAndAutodiffGradients(
   }
 }
 
+void CompareManualAndAutodiffGradients(
+    Model grad_model, AbstractContext* ctx,
+    absl::Span<AbstractTensorHandle* const> inputs,
+    absl::Span<const float> manuals, bool use_function, double abs_error) {
+  auto num_inputs = inputs.size();
+  std::vector<AbstractTensorHandle*> outputs(num_inputs);
+  auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),
+                    /*use_function=*/use_function);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  int current_index_manual = 0;
+  for (int i = 0; i < num_inputs; ++i) {
+    if (!outputs[i]) continue;
+
+    TF_Tensor* analytical_tensor;
+    s = GetValue(outputs[i], &analytical_tensor);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    auto num_elem_analytical = TF_TensorElementCount(analytical_tensor);
+
+    float* danalytical = new float[num_elem_analytical]{0};
+    memcpy(&danalytical[0], TF_TensorData(analytical_tensor),
+           TF_TensorByteSize(analytical_tensor));
+
+    for (int j = 0; j < num_elem_analytical; j++) {
+      ASSERT_NEAR(manuals[current_index_manual], danalytical[j], abs_error);
+      ++current_index_manual;
+    }
+    TF_DeleteTensor(analytical_tensor);
+    delete[] danalytical;
+    outputs[i]->Unref();
+  }
+}
+
 }  // namespace internal
 }  // namespace gradients
 }  // namespace tensorflow

@@ -26,6 +26,15 @@ void CompareNumericalAndAutodiffGradients(
     absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
     double abs_error = 1e-2);
 
+// `manuals` should be a flat array of expected results of `grad_model`. e.g if
+// `grad_model` output is `[[1, 2], nullptr, [3, 4]]`, `manuals` will be `[1,
+// 2, 3, 4]`.
+void CompareManualAndAutodiffGradients(
+    Model grad_model, AbstractContext* ctx,
+    absl::Span<AbstractTensorHandle* const> inputs,
+    absl::Span<const float> manuals, bool use_function,
+    double abs_error = 1e-2);
+
 }  // namespace internal
 }  // namespace gradients
 }  // namespace tensorflow

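The header comment above defines the flattening convention for `manuals`: the expected gradients for all inputs are passed as one flat span, in input order, skipping inputs whose gradient is null. A minimal usage sketch, illustrative only (`SomeGradModel`, `a`, `b`, `c`, and `ctx` are hypothetical names, not part of this commit):

  // Suppose SomeGradModel returns [[1, 2], nullptr, [3, 4]] for inputs
  // {a, b, c}. The nullptr entry is skipped, so `manuals` is {1, 2, 3, 4}.
  std::vector<float> manuals = {1.0f, 2.0f, 3.0f, 4.0f};
  CompareManualAndAutodiffGradients(SomeGradModel, ctx.get(),
                                    {a.get(), b.get(), c.get()}, manuals,
                                    /*use_function=*/false);
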
@@ -29,6 +29,28 @@ namespace {
 
 using tensorflow::TF_StatusPtr;
 
+Status ReluGradModel(AbstractContext* ctx,
+                     absl::Span<AbstractTensorHandle* const> inputs,
+                     absl::Span<AbstractTensorHandle*> outputs) {
+  GradientRegistry registry;
+  TF_RETURN_IF_ERROR(registry.Register("Relu", ReluRegisterer));
+
+  Tape tape(/*persistent=*/false);
+  tape.Watch(inputs[0]);
+  std::vector<AbstractTensorHandle*> temp_outputs(1);
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
+                               absl::MakeSpan(temp_outputs), "ReluGrad"));
+
+  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
+                                          /*sources=*/inputs,
+                                          /*output_gradients=*/{}, outputs));
+  for (auto temp_output : temp_outputs) {
+    temp_output->Unref();
+  }
+  return Status::OK();
+}
+
 Status SparseSoftmaxCrossEntropyWithLogitsModel(
     AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
     absl::Span<AbstractTensorHandle*> outputs) {

@@ -125,6 +147,25 @@ class CppGradients
   bool UseFunction() const { return std::get<2>(GetParam()); }
 };
 
+TEST_P(CppGradients, TestReluGrad) {
+  // Mathematically, Relu isn't differentiable at `0`. So `gradient_checker`
+  // does not work with it.
+  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
+  int64_t X_dims[] = {2, 2};
+  AbstractTensorHandlePtr X;
+  {
+    AbstractTensorHandle* X_raw;
+    Status s =
+        TestTensorHandleWithDimsFloat(ctx_.get(), X_vals, X_dims, 2, &X_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    X.reset(X_raw);
+  }
+
+  ASSERT_NO_FATAL_FAILURE(CompareManualAndAutodiffGradients(
+      ReluGradModel, ctx_.get(), {X.get()},
+      {1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, UseFunction()));
+}
+
 TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
   if (UseFunction()) {
     // TODO(b/168850692): Enable this.

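The manual gradient list in TestReluGrad above follows from the ReLU derivative: 1 for positive inputs and 0 otherwise (TensorFlow's ReluGrad also returns 0 at exactly 0). A minimal sketch of that hand derivation, illustrative only and not part of this commit:

  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
  float expected_dX[9];
  for (int i = 0; i < 9; ++i) {
    // d/dx relu(x) = 1 if x > 0, else 0 (including x == 0).
    expected_dX[i] = X_vals[i] > 0.0f ? 1.0f : 0.0f;
  }
  // expected_dX == {1, 1, 1, 0, 0, 0, 1, 0, 0}, matching the list passed to
  // CompareManualAndAutodiffGradients.
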
@@ -158,8 +199,7 @@ TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
       SparseSoftmaxCrossEntropyWithLogitsModel,
       SparseSoftmaxCrossEntropyWithLogitsGradModel, ctx_.get(),
-      {X.get(), Y.get()},
-      /*use_function=*/UseFunction()));
+      {X.get(), Y.get()}, UseFunction()));
 }
 
 TEST_P(CppGradients, TestBiasAddGrad) {

@@ -192,7 +232,7 @@ TEST_P(CppGradients, TestBiasAddGrad) {
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
       BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()},
-      /*use_function=*/UseFunction()));
+      UseFunction()));
 }
 
 #ifdef PLATFORM_GOOGLE