diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index cfd12ebf6c1..bc4c29852f1 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -248,6 +248,7 @@ cc_library( ":c_api_unified_internal", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_tensor", "//tensorflow/core:framework", "//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/platform:errors", diff --git a/tensorflow/c/eager/unified_api_testutil.cc b/tensorflow/c/eager/unified_api_testutil.cc index 9e8683df0e7..9907fc28c3d 100644 --- a/tensorflow/c/eager/unified_api_testutil.cc +++ b/tensorflow/c/eager/unified_api_testutil.cc @@ -144,18 +144,43 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value, } Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data, - int64* dims, int num_dims, + int64_t* dims, int num_dims, AbstractTensorHandle** tensor) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = TestTensorHandleWithDimsFloat( - eager_ctx, data, reinterpret_cast(dims), num_dims); + TFE_TensorHandle* input_eager = + TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims); *tensor = unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); return Status::OK(); } +Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data, + int64_t* dims, int num_dims, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = + TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return Status::OK(); +} + +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + return StatusFromTF_Status(status.get()); +} + } // namespace tensorflow diff --git a/tensorflow/c/eager/unified_api_testutil.h b/tensorflow/c/eager/unified_api_testutil.h index eb8d0ffa897..39bf553efa5 100644 --- a/tensorflow/c/eager/unified_api_testutil.h +++ b/tensorflow/c/eager/unified_api_testutil.h @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { @@ -54,8 +55,16 @@ Status TestScalarTensorHandle(AbstractContext* ctx, float value, // Get a Matrix TensorHandle with given float values and dimensions. Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data, - int64* dims, int num_dims, + int64_t* dims, int num_dims, AbstractTensorHandle** tensor); + +// Get a TensorHandle with given int values and dimensions +Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data, + int64_t* dims, int num_dims, + AbstractTensorHandle** tensor); + +// Places data from `t` into *result_tensor. +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor); } // namespace tensorflow #endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_ diff --git a/tensorflow/c/experimental/gradients/custom_gradient_test.cc b/tensorflow/c/experimental/gradients/custom_gradient_test.cc index 9ca018763c8..16fb3394dd8 100644 --- a/tensorflow/c/experimental/gradients/custom_gradient_test.cc +++ b/tensorflow/c/experimental/gradients/custom_gradient_test.cc @@ -86,16 +86,6 @@ Status ExpWithPassThroughGrad(AbstractContext* ctx, return Status::OK(); } -Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_TensorHandle* result_t = - TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); - return Status::OK(); -} - TEST_P(CustomGradientTest, ExpWithPassThroughGrad) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -128,7 +118,7 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) { ASSERT_EQ(errors::OK, s.code()) << s.error_message(); TF_Tensor* result_tensor; - s = getValue(outputs[0], &result_tensor); + s = GetValue(outputs[0], &result_tensor); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); auto result_value = static_cast(TF_TensorData(result_tensor)); EXPECT_EQ(*result_value, 1.0);