STT-tensorflow/tensorflow/c/eager/tracing_utils.cc
Saurabh Saxena 7ec70d54b2 - Integrate C++ tape with op building APIs via TapeContext and TapeOperation which delegate calls to a parent execution context and record operations on the tape. Please see gradients_test.cc for usage.
- This will replace the helper functions in gradients_internal.h. I will clean that up in a followup CL.
- Also drop ForwardOperation::ctx since that is unused right now. We can add it later if we need.

PiperOrigin-RevId: 333390787
Change-Id: I80f2c460a9538a1a14ed1497c59f7b37a633a633
2020-09-23 16:12:39 -07:00

38 lines
1.4 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/tracing_utils.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace tracing {

// Sets the name of `op` if it supports naming, and is a no-op otherwise.
//
// - If `op` is a TracingOperation, its SetOpName(op_name) is called and that
//   status is returned.
// - If `op` is a gradients::TapeOperation, the call is forwarded recursively
//   to the operation it wraps (GetBackingOperation()). Stacked tape wrappers
//   are therefore unwrapped until a non-tape operation is reached.
// - Any other kind of operation is left untouched and Status::OK() is
//   returned.
//
// `op` must be non-null: the LLVM-style dyn_cast used here does not accept
// null pointers.
Status MaybeSetOpName(AbstractOperation* op, const char* op_name) {
  // A single dyn_cast both tests and converts the kind. The previous
  // isa<> + dyn_cast<> pair performed the same RTTI check twice.
  if (auto* tracing_op = dyn_cast<TracingOperation>(op)) {
    return tracing_op->SetOpName(op_name);
  } else if (auto* tape_op = dyn_cast<gradients::TapeOperation>(op)) {
    // Tape operations only record; naming applies to the operation they
    // delegate to.
    return MaybeSetOpName(tape_op->GetBackingOperation(), op_name);
  }
  return Status::OK();
}

}  // namespace tracing
}  // namespace tensorflow