Fix segfault in mnist_gradients_test and enable it on TAP.

PiperOrigin-RevId: 328814584
Change-Id: Icfb075332005988d1aa1ee2cc52b4a6be5b94f47
This commit is contained in:
Saurabh Saxena 2020-08-27 14:16:58 -07:00 committed by TensorFlower Gardener
parent 6e9d916229
commit ffc9bef49c
4 changed files with 22 additions and 9 deletions

View File

@ -269,6 +269,7 @@ cc_library(
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/types:span",
],
@ -284,8 +285,6 @@ tf_cuda_cc_test(
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [
"nomac",
"notap", # TODO(b/166150182): Enable
"no_oss", # TODO(b/166150182): Enable
],
deps = [
":abstract_tensor_handle",

View File

@ -762,6 +762,7 @@ TEST_P(CppGradients, TestMNIST_Training) {
mnist_outputs[2]->Unref(); // release loss
}
// TODO(b/166648529): Enable for mlir.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,

View File

@ -31,11 +31,15 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
using std::vector;
using tracing::TracingOperation;
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
using std::vector;
using tensorflow::tracing::TracingOperation;
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
@ -272,6 +276,7 @@ Status MNISTForwardModel(AbstractContext* ctx,
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
@ -592,3 +597,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
TFE_DeleteContextOptions(opts);
return Status::OK();
}
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -27,13 +27,13 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
using namespace tensorflow;
using namespace tensorflow::gradients;
using namespace tensorflow::gradients::internal;
#include "tensorflow/core/platform/status.h"
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
@ -144,3 +144,7 @@ Status RunModel(Model model, AbstractContext* ctx,
const GradientRegistry& registry);
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace internal
} // namespace gradients
} // namespace tensorflow