diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index f9d0f667006..99c4a213d2d 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -344,6 +344,7 @@ tf_cc_test( ":xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:test_main", + "//tensorflow/core/platform:test_benchmark", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc index 47680a6ba32..1094cdb918f 100644 --- a/tensorflow/compiler/xla/shape_test.cc +++ b/tensorflow/compiler/xla/shape_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace xla { namespace { @@ -218,5 +219,34 @@ TEST_F(ShapeTest, SupportsAbslHash) { nested_tuple_, dynamic_matrix_})); } +void BM_ShapeCopy(::testing::benchmark::State& state) { + // Create different shapes based on benchmark parameters: + Shape shape; + switch (state.range(0)) { + case 0: { + // Shape() + break; + } + case 1: { + // f32[1,2,2]{2,1,0} + shape = Shape(F32, {1, 2, 2}, {false, false, false}, {}); + *shape.mutable_layout() = Layout({2, 1, 0}); + break; + } + case 2: { + // f32[1,2,2]{2,1,0:T(2,128)} + shape = Shape(F32, {1, 2, 2}, {false, false, false}, {}); + *shape.mutable_layout() = Layout({2, 1, 0}, {Tile({2, 128})}); + break; + } + } + state.SetLabel(shape.ToString(true)); + + for (auto s : state) { + Shape copy(shape); + } +} +BENCHMARK(BM_ShapeCopy)->Arg(0)->Arg(1)->Arg(2); + } // namespace } // namespace xla