updating rebase
This commit is contained in:
parent
083e5f48fd
commit
0926d6f238
@ -428,6 +428,32 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mnist_gradients_util",
|
||||
srcs = [
|
||||
"mnist_gradients_util.cc",
|
||||
"mnist_gradients_util.h",
|
||||
],
|
||||
hdrs = [
|
||||
"gradients.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_operation",
|
||||
":abstract_tensor_handle",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
":tape",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "mnist_gradients_test",
|
||||
size = "small",
|
||||
@ -444,6 +470,7 @@ tf_cuda_cc_test(
|
||||
":c_api_test_util",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
":mnist_gradients_util",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
|
@ -29,7 +29,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
|
@ -460,6 +460,9 @@ TEST_P(CppGradients, TestMatMulGrad) {
|
||||
s = MatMulGradModel(ctx.get(), {A.get(), B.get()}, absl::MakeSpan(outputs), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// s = MatMulGradModel(ctx.get(), {A.get(), B.get()}, absl::MakeSpan(outputs), registry);
|
||||
// ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* dA_tensor;
|
||||
s = getValue(outputs[0], &dA_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
@ -584,13 +587,9 @@ TEST_P(CppGradients, TestMNISTForward2) {
|
||||
|
||||
// Run the Forward Pass
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
<<<<<<< HEAD
|
||||
Status s = RunModel(MNISTForwardModel, ctx.get(), {X.get(), W1.get(), W2.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
=======
|
||||
Status s = MNISTForwardModel(ctx.get(), {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), registry);
|
||||
>>>>>>> 37eefa1df8... Adding tests for matmul grad, memory error
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Verify the Results
|
||||
@ -601,10 +600,6 @@ TEST_P(CppGradients, TestMNISTForward2) {
|
||||
float result_data[6] = {0};
|
||||
memcpy(&result_data[0], TF_TensorData(scores_tensor), TF_TensorByteSize(scores_tensor));
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
//float expected_scores [6] = {0f, 12.0f, -1.0f, -17.0f, 16.8f, -28.0f};
|
||||
>>>>>>> 37eefa1df8... Adding tests for matmul grad, memory error
|
||||
float expected_scores [6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f};
|
||||
float tolerance = 1e-3;
|
||||
for(int j = 0; j < 6; j++){
|
||||
@ -638,10 +633,7 @@ Status MatMulTransposeModel(AbstractContext* ctx,
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
<<<<<<< HEAD
|
||||
tape->Watch(ToId(X));
|
||||
=======
|
||||
>>>>>>> 37eefa1df8... Adding tests for matmul grad, memory error
|
||||
tape->Watch(ToId(W1)); // Watch W1.
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
@ -682,15 +674,11 @@ TEST_P(CppGradients, TestMatMulTranspose) {
|
||||
|
||||
// Run the MatMul Op
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
<<<<<<< HEAD
|
||||
|
||||
Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
|
||||
=======
|
||||
Status s = MatMulTransposeModel(ctx.get(), {X.get(), W1.get()}, absl::MakeSpan(outputs), registry);
|
||||
>>>>>>> 37eefa1df8... Adding tests for matmul grad, memory error
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Verify the Results
|
||||
@ -701,10 +689,6 @@ TEST_P(CppGradients, TestMatMulTranspose) {
|
||||
float result_data[6] = {0};
|
||||
memcpy(&result_data[0], TF_TensorData(scores_tensor), TF_TensorByteSize(scores_tensor));
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
|
||||
>>>>>>> 37eefa1df8... Adding tests for matmul grad, memory error
|
||||
float expected_scores [6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
|
||||
float tolerance = 1e-3;
|
||||
|
||||
@ -714,7 +698,6 @@ TEST_P(CppGradients, TestMatMulTranspose) {
|
||||
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
// Test Model to verify ReluGrad functionality
|
||||
Status ReluGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
@ -927,8 +910,6 @@ TEST_P(CppGradients, TestSoftmaxLossGrad) {
|
||||
TF_DeleteTensor(dX_tensor);
|
||||
}
|
||||
|
||||
=======
|
||||
>>>>>>> 37eefa1df8... Adding tests for matmul grad, memory error
|
||||
|
||||
// TODO(b/160888630): Enable this test with mlir after AddInputList is
|
||||
// supported. It is needed for AddN op which is used for gradient aggregation.
|
||||
|
Loading…
Reference in New Issue
Block a user