From ffc9bef49cd7b6781e189d8e090adbdea0c8d342 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 27 Aug 2020 14:16:58 -0700 Subject: [PATCH] Fix segfault in mnist_gradients_test and enable it on TAP. PiperOrigin-RevId: 328814584 Change-Id: Icfb075332005988d1aa1ee2cc52b4a6be5b94f47 --- tensorflow/c/eager/BUILD | 3 +-- tensorflow/c/eager/mnist_gradients_test.cc | 1 + tensorflow/c/eager/mnist_gradients_testutil.cc | 15 ++++++++++++--- tensorflow/c/eager/mnist_gradients_testutil.h | 12 ++++++++---- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index dc8b2c3f1f8..1a3b348e8f9 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -269,6 +269,7 @@ cc_library( "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:span", ], @@ -284,8 +285,6 @@ tf_cuda_cc_test( linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + [ "nomac", - "notap", # TODO(b/166150182): Enable - "no_oss", # TODO(b/166150182): Enable ], deps = [ ":abstract_tensor_handle", diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc index d6dd94806a7..116df2264ae 100644 --- a/tensorflow/c/eager/mnist_gradients_test.cc +++ b/tensorflow/c/eager/mnist_gradients_test.cc @@ -762,6 +762,7 @@ TEST_P(CppGradients, TestMNIST_Training) { mnist_outputs[2]->Unref(); // release loss } +// TODO(b/166648529): Enable for mlir. #ifdef PLATFORM_GOOGLE INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc index 4b2c87c678d..9f5d0d149d4 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.cc +++ b/tensorflow/c/eager/mnist_gradients_testutil.cc @@ -31,11 +31,15 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -using std::vector; -using tracing::TracingOperation; - // ========================== Tape Ops ============================== +namespace tensorflow { +namespace gradients { +namespace internal { + +using std::vector; +using tensorflow::tracing::TracingOperation; + // Computes `inputs[0] + inputs[1]` and records it on the tape. Status Add(AbstractContext* ctx, Tape* tape, absl::Span inputs, @@ -272,6 +276,7 @@ Status MNISTForwardModel(AbstractContext* ctx, AbstractTensorHandle* scores = temp_outputs[0]; + temp_outputs.resize(2); TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), "softmax_loss", registry)); // Compute Softmax(Scores,labels) @@ -592,3 +597,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { TFE_DeleteContextOptions(opts); return Status::OK(); } + +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h index b6de8ff6788..efe196e9ba3 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.h +++ b/tensorflow/c/eager/mnist_gradients_testutil.h @@ -27,13 +27,13 @@ limitations under the License. #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" - -using namespace tensorflow; -using namespace tensorflow::gradients; -using namespace tensorflow::gradients::internal; +#include "tensorflow/core/platform/status.h" // ========================== Tape Ops ============================== +namespace tensorflow { +namespace gradients { +namespace internal { // Computes `inputs[0] + inputs[1]` and records it on the tape. Status Add(AbstractContext* ctx, Tape* tape, absl::Span inputs, @@ -144,3 +144,7 @@ Status RunModel(Model model, AbstractContext* ctx, const GradientRegistry& registry); Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); + +} // namespace internal +} // namespace gradients +} // namespace tensorflow