Fix segfault in mnist_gradients_test and enable it on TAP.
PiperOrigin-RevId: 328814584 Change-Id: Icfb075332005988d1aa1ee2cc52b4a6be5b94f47
This commit is contained in:
parent
6e9d916229
commit
ffc9bef49c
@ -269,6 +269,7 @@ cc_library(
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
@ -284,8 +285,6 @@ tf_cuda_cc_test(
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"nomac",
|
||||
"notap", # TODO(b/166150182): Enable
|
||||
"no_oss", # TODO(b/166150182): Enable
|
||||
],
|
||||
deps = [
|
||||
":abstract_tensor_handle",
|
||||
|
@ -762,6 +762,7 @@ TEST_P(CppGradients, TestMNIST_Training) {
|
||||
mnist_outputs[2]->Unref(); // release loss
|
||||
}
|
||||
|
||||
// TODO(b/166648529): Enable for mlir.
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCAPI, CppGradients,
|
||||
|
@ -31,11 +31,15 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
using std::vector;
|
||||
using tracing::TracingOperation;
|
||||
|
||||
// ========================== Tape Ops ==============================
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
|
||||
using std::vector;
|
||||
using tensorflow::tracing::TracingOperation;
|
||||
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
@ -272,6 +276,7 @@ Status MNISTForwardModel(AbstractContext* ctx,
|
||||
|
||||
AbstractTensorHandle* scores = temp_outputs[0];
|
||||
|
||||
temp_outputs.resize(2);
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
|
||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
|
||||
@ -592,3 +597,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -27,13 +27,13 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
using namespace tensorflow;
|
||||
using namespace tensorflow::gradients;
|
||||
using namespace tensorflow::gradients::internal;
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
// ========================== Tape Ops ==============================
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
@ -144,3 +144,7 @@ Status RunModel(Model model, AbstractContext* ctx,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user