add GetValue and TestTensorHandleWithDimsInt to unified_api_testutil

This commit is contained in:
Võ Văn Nghĩa 2020-12-10 00:25:36 +07:00
parent 7c5ca02d9d
commit 2234086df0
4 changed files with 40 additions and 15 deletions

View File

@ -248,6 +248,7 @@ cc_library(
":c_api_unified_internal",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",

View File

@ -144,18 +144,43 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value,
}
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
int64* dims, int num_dims,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestTensorHandleWithDimsFloat(
eager_ctx, data, reinterpret_cast<int64_t*>(dims), num_dims);
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return StatusFromTF_Status(status.get());
}
} // namespace tensorflow

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
@ -54,8 +55,16 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value,
// Get a Matrix TensorHandle with given float values and dimensions.
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
int64* dims, int num_dims,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor);
// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_

View File

@ -86,16 +86,6 @@ Status ExpWithPassThroughGrad(AbstractContext* ctx,
return Status::OK();
}
Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return Status::OK();
}
TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -128,7 +118,7 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
s = GetValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);