Fix failures in eager/c_api_test benchmarks.

PiperOrigin-RevId: 357243746
Change-Id: I242d7900f0ed50e984f56050b1c44f384ed40dda
This commit is contained in:
Xiao Yu 2021-02-12 12:20:02 -08:00 committed by TensorFlower Gardener
parent 048e5baef9
commit e3cf51e3e2

View File

@ -1168,16 +1168,17 @@ void BM_ExecuteFunction(int iters, int async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(matmul, &retval[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
}
if (async) {
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
@ -1249,16 +1250,15 @@ void BM_ReadVariable(int iters) {
TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 1;
TFE_TensorHandle* h = nullptr;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(op, &h, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK_EQ(1, num_retvals);
@ -1267,11 +1267,9 @@ void BM_ReadVariable(int iters) {
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
h = nullptr;
TFE_OpAddInput(op, var_handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
}
tensorflow::testing::StopTiming();
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);