add GetValue
and TestTensorHandleWithDimsInt
to unified_api_testutil
This commit is contained in:
parent
7c5ca02d9d
commit
2234086df0
tensorflow/c
eager
experimental/gradients
@ -248,6 +248,7 @@ cc_library(
|
||||
":c_api_unified_internal",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
|
@ -144,18 +144,43 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
}
|
||||
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
||||
int64* dims, int num_dims,
|
||||
int64_t* dims, int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestTensorHandleWithDimsFloat(
|
||||
eager_ctx, data, reinterpret_cast<int64_t*>(dims), num_dims);
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
|
||||
int64_t* dims, int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_TensorHandle* result_t =
|
||||
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
|
||||
return StatusFromTF_Status(status.get());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -54,8 +55,16 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
|
||||
// Get a Matrix TensorHandle with given float values and dimensions.
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
||||
int64* dims, int num_dims,
|
||||
int64_t* dims, int num_dims,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
// Get a TensorHandle with given int values and dimensions
|
||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
|
||||
int64_t* dims, int num_dims,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
// Places data from `t` into *result_tensor.
|
||||
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
|
||||
|
@ -86,16 +86,6 @@ Status ExpWithPassThroughGrad(AbstractContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_TensorHandle* result_t =
|
||||
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -128,7 +118,7 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
s = GetValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
|
Loading…
Reference in New Issue
Block a user