Add a micro benchmark for OpKernel::TraceString

PiperOrigin-RevId: 301605173
Change-Id: I9003d08220448f79a0144d637cd0fb00ad54f7b7
This commit is contained in:
A. Unique TensorFlower 2020-03-18 09:32:53 -07:00 committed by TensorFlower Gardener
parent 37dfcf6e2a
commit a259351964

View File

@ -1026,6 +1026,7 @@ void BM_InputRangeHelper(int iters, const NodeDef& node_def,
// Register DummyKernel as the CPU kernel for the ops that the benchmarks in
// this file build kernels for. The "MatMul" registration is what lets
// BM_TraceString create a kernel for its MatMul NodeDef; "ConcatV2" and
// "Select" are presumably consumed by the InputRange benchmarks (their
// NodeDefs are not fully visible here -- confirm against those benchmarks).
REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("MatMul").Device(DEVICE_CPU), DummyKernel);
void BM_ConcatInputRange(int iters) {
testing::StopTiming();
@ -1067,8 +1068,51 @@ void BM_SelectInputRange(int iters) {
BM_InputRangeHelper(iters, node_def, "condition", 0, 1);
}
void BM_TraceString(const int iters, const int verbose) {
testing::StopTiming();
// Create a MatMul NodeDef with 2 inputs.
NodeDef node_def;
node_def.set_name("gradient_tape/model_1/dense_1/MatMul_1");
node_def.set_op("MatMul");
AttrValue transpose_a, transpose_b, attr_t;
attr_t.set_type(DT_FLOAT);
node_def.mutable_attr()->insert({"T", attr_t});
transpose_a.set_b(true);
node_def.mutable_attr()->insert({"transpose_a", transpose_a});
transpose_b.set_b(true);
node_def.mutable_attr()->insert({"transpose_b", transpose_b});
for (size_t i = 0; i < 2; ++i) {
node_def.add_input(strings::StrCat("a:", i));
}
// Build OpKernel and OpKernelContext
Status status;
auto device = absl::make_unique<DummyDevice>(Env::Default());
std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
cpu_allocator(), node_def,
TF_GRAPH_DEF_VERSION, &status));
TF_CHECK_OK(status);
OpKernelContext::Params params;
params.device = device.get();
params.op_kernel = op.get();
Tensor a(DT_FLOAT, TensorShape({99000, 256}));
Tensor b(DT_FLOAT, TensorShape({256, 256}));
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b)};
params.inputs = &inputs;
auto ctx = absl::make_unique<OpKernelContext>(&params);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
auto trace = op->TraceString(ctx.get(), verbose);
}
testing::StopTiming();
}
// Benchmark registrations. BM_TraceString is registered with two arguments,
// 1 and 0, which it receives as its `verbose` parameter.
BENCHMARK(BM_ConcatInputRange);
BENCHMARK(BM_SelectInputRange);
BENCHMARK(BM_TraceString)->Arg(1)->Arg(0);
TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) {
auto kernel_list = GetAllRegisteredKernels();