Fix failures in eager/c_api_test benchmarks.

PiperOrigin-RevId: 357243746
Change-Id: I242d7900f0ed50e984f56050b1c44f384ed40dda
This commit is contained in:
Xiao Yu 2021-02-12 12:20:02 -08:00 committed by TensorFlower Gardener
parent 048e5baef9
commit e3cf51e3e2

View File

@ -1168,16 +1168,17 @@ void BM_ExecuteFunction(int iters, int async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retval[1] = {nullptr}; TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1; int num_retvals = 1;
tensorflow::testing::StartTiming(); tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) { for (int i = 0; i < iters; ++i) {
TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(matmul, &retval[0], &num_retvals, status); TFE_Execute(matmul, &retval[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
} }
if (async) { if (async) {
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
@ -1249,16 +1250,15 @@ void BM_ReadVariable(int iters) {
TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0); TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 1; int num_retvals = 1;
TFE_TensorHandle* h = nullptr; TFE_TensorHandle* h = nullptr;
tensorflow::testing::StartTiming(); tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) { for (int i = 0; i < iters; ++i) {
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(op, &h, &num_retvals, status); TFE_Execute(op, &h, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK_EQ(1, num_retvals); CHECK_EQ(1, num_retvals);
@ -1267,11 +1267,9 @@ void BM_ReadVariable(int iters) {
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
h = nullptr; h = nullptr;
TFE_OpAddInput(op, var_handle, status); TFE_DeleteOp(op);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
} }
tensorflow::testing::StopTiming(); tensorflow::testing::StopTiming();
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(var_handle); TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx); TFE_DeleteContext(ctx);