add GetValue
and TestTensorHandleWithDimsInt
to unified_api_testutil
This commit is contained in:
parent
7c5ca02d9d
commit
2234086df0
@ -248,6 +248,7 @@ cc_library(
|
|||||||
":c_api_unified_internal",
|
":c_api_unified_internal",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
|
"//tensorflow/c:tf_tensor",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
"//tensorflow/core/platform:errors",
|
"//tensorflow/core/platform:errors",
|
||||||
|
@ -144,18 +144,43 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
||||||
int64* dims, int num_dims,
|
int64_t* dims, int num_dims,
|
||||||
AbstractTensorHandle** tensor) {
|
AbstractTensorHandle** tensor) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
TFE_Context* eager_ctx =
|
TFE_Context* eager_ctx =
|
||||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||||
TFE_TensorHandle* input_eager = TestTensorHandleWithDimsFloat(
|
TFE_TensorHandle* input_eager =
|
||||||
eager_ctx, data, reinterpret_cast<int64_t*>(dims), num_dims);
|
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
|
||||||
*tensor =
|
*tensor =
|
||||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
|
||||||
|
int64_t* dims, int num_dims,
|
||||||
|
AbstractTensorHandle** tensor) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
TFE_Context* eager_ctx =
|
||||||
|
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||||
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||||
|
TFE_TensorHandle* input_eager =
|
||||||
|
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
|
||||||
|
*tensor =
|
||||||
|
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
TFE_TensorHandle* result_t =
|
||||||
|
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
|
||||||
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||||
|
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
|
||||||
|
return StatusFromTF_Status(status.get());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/eager/abstract_context.h"
|
#include "tensorflow/c/eager/abstract_context.h"
|
||||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -54,8 +55,16 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
|||||||
|
|
||||||
// Get a Matrix TensorHandle with given float values and dimensions.
|
// Get a Matrix TensorHandle with given float values and dimensions.
|
||||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
||||||
int64* dims, int num_dims,
|
int64_t* dims, int num_dims,
|
||||||
AbstractTensorHandle** tensor);
|
AbstractTensorHandle** tensor);
|
||||||
|
|
||||||
|
// Get a TensorHandle with given int values and dimensions
|
||||||
|
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
|
||||||
|
int64_t* dims, int num_dims,
|
||||||
|
AbstractTensorHandle** tensor);
|
||||||
|
|
||||||
|
// Places data from `t` into *result_tensor.
|
||||||
|
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
|
#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
|
||||||
|
@ -86,16 +86,6 @@ Status ExpWithPassThroughGrad(AbstractContext* ctx,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
|
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
|
||||||
TFE_TensorHandle* result_t =
|
|
||||||
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
|
|
||||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
|
||||||
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
|
TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
@ -128,7 +118,7 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
|
|||||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
|
||||||
TF_Tensor* result_tensor;
|
TF_Tensor* result_tensor;
|
||||||
s = getValue(outputs[0], &result_tensor);
|
s = GetValue(outputs[0], &result_tensor);
|
||||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||||
EXPECT_EQ(*result_value, 1.0);
|
EXPECT_EQ(*result_value, 1.0);
|
||||||
|
Loading…
Reference in New Issue
Block a user