From b32ec281aa66f6bf3e32394c2f2eed6e523c966b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 22 Nov 2020 20:35:25 -0800 Subject: [PATCH] Speed up Shape creation by avoiding unnecessary validations. Before: BM_MakeShape 73.1ns +- 0% After: BM_MakeShape 11.2ns +- 1% PiperOrigin-RevId: 343779785 Change-Id: Iedc1ccc784cbbc73ad310a6009e4125601da68da --- tensorflow/compiler/xla/BUILD | 1 + .../compiler/xla/client/xla_builder_test.cc | 3 +- .../xla_client_backend_independent_test.py | 3 +- tensorflow/compiler/xla/shape_util.cc | 94 +++++++++++++++++-- tensorflow/compiler/xla/shape_util.h | 5 + tensorflow/compiler/xla/shape_util_test.cc | 15 +++ tensorflow/compiler/xla/xla_data.proto | 1 + 7 files changed, 110 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 9883a3f47c5..d44f9991936 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -361,6 +361,7 @@ tf_cc_test( ":util", ":xla_data_proto_cc", "//tensorflow/core:lib", + "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 947850cb049..cd21c6dc414 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -338,8 +338,7 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) { /*broadcast_dimensions=*/{0, 1, 2}); auto statusor = BuildHloModule(&b); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("shape's dimensions must not be < 0")); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("invalid shape")); } TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { diff --git a/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py index 180bb040cc4..eb9c90941b6 
100644 --- a/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py +++ b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py @@ -38,8 +38,7 @@ ops = xla_client.ops class ShapeTest(absltest.TestCase): def testInvalidShapes(self): - with self.assertRaisesRegex(RuntimeError, - "shape's dimensions must not be < 0.*"): + with self.assertRaisesRegex(RuntimeError, "invalid shape"): xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4]) with self.assertRaisesRegex( diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index cb0edfb6be6..e84a2591707 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -52,6 +52,33 @@ namespace xla { using absl::StrAppend; using absl::StrCat; +namespace { +// An array indexed by PrimitiveType that gives the size in bytes of +// one element of that type, or 0 for values that have no element size +// (PRIMITIVE_TYPE_INVALID, TUPLE, OPAQUE_TYPE, TOKEN). +constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = { + 0, // PRIMITIVE_TYPE_INVALID = 0, + sizeof(int8), // PRED = 1 + sizeof(int8), // S8 = 2 + sizeof(int16), // S16 = 3 + sizeof(int32), // S32 = 4 + sizeof(int64), // S64 = 5 + sizeof(uint8), // U8 = 6 + sizeof(uint16), // U16 = 7 + sizeof(uint32), // U32 = 8 + sizeof(uint64), // U64 = 9 + sizeof(float) / 2, // F16 = 10 + sizeof(float), // F32 = 11 + sizeof(double), // F64 = 12 + 0, // TUPLE = 13 + 0, // OPAQUE_TYPE = 14 + sizeof(complex64), // C64 = 15 + sizeof(float) / 2, // BF16 = 16 + 0, // TOKEN = 17 + sizeof(complex128) // C128 = 18 +}; +} // namespace + string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); } string ShapeIndexView::ToString() const { @@ -175,6 +202,42 @@ StatusOr MakeShapeWithLayoutInternal( return accum; } +/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type, + absl::Span dimensions, + Shape* shape) { + const int eint = static_cast(element_type); + int64 dense_shape_size = 
((eint >= 0 && eint < PrimitiveType_ARRAYSIZE) + ? primitive_byte_size[eint] + : 0); // Out of range: force a failure + if (dense_shape_size <= 0) { + return false; + } + + // Verify that array-based lookup is consistent with public API. + DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type)) + << element_type; + + shape->set_element_type(element_type); + const int ndims = dimensions.size(); + auto layout = shape->mutable_layout(); + layout->set_format(DENSE); + auto* minor_to_major = layout->mutable_minor_to_major(); + for (int i = 0; i < ndims; i++) { + const int64 d = dimensions[i]; + if (d < 0) { + return false; + } + dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d); + if (dense_shape_size < 0) { + return false; + } + + shape->add_dimensions(d); + minor_to_major->push_back(ndims - 1 - i); + } + return true; +} + /* static */ ProgramShape ShapeUtil::MakeProgramShape( std::initializer_list parameters, Shape result) { ProgramShape program_shape; @@ -187,7 +250,9 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, absl::Span dimensions) { - return MakeValidatedShape(element_type, dimensions).ValueOrDie(); + Shape shape; + CHECK(FillNewShape(element_type, dimensions, &shape)); + return shape; } /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) { @@ -210,18 +275,31 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions) { - CHECK(IsArrayPrimitiveType(element_type)) << element_type; - Shape result; - TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result)); - return result; + Shape shape; + if (!FillNewShape(element_type, dimensions, &shape)) { + return InvalidArgument("invalid shape type=%d, dims=[%s]", + static_cast(element_type), + absl::StrJoin(dimensions, ",")); + } + return shape; } /* static */ StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType 
element_type, absl::Span dimensions, const std::vector& dynamic_dimensions) { - TF_ASSIGN_OR_RETURN(Shape shape, - MakeValidatedShape(element_type, dimensions)); - for (int i = 0; i < dynamic_dimensions.size(); ++i) { + if (dynamic_dimensions.size() != dimensions.size()) { + return InvalidArgument( + "dynamic dimensions size %d did not match number of dimensions %d", + dynamic_dimensions.size(), dimensions.size()); + } + + Shape shape; + if (!FillNewShape(element_type, dimensions, &shape)) { + return InvalidArgument("invalid shape type=%d, dims=[%s]", + static_cast(element_type), + absl::StrJoin(dimensions, ",")); + } + for (int i = 0, n = dimensions.size(); i < n; i++) { shape.set_dynamic_dimension(i, dynamic_dimensions[i]); } return shape; diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index c1a6a2c8b1d..ff47ab6ea80 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -792,6 +792,11 @@ class ShapeUtil { static bool CanUpcastIntegral(const Shape& from, const Shape& to); private: + // Fills *shape. Returns true on success; on failure *shape may be left + // partially filled. REQUIRES: *shape is empty. + static bool FillNewShape(PrimitiveType element_type, + absl::Span dimensions, Shape* shape); + // Validates the shape size is sane. This makes sure it's safe to do // calculations in int64 without overflowing. static Status ValidateShapeSize(const Shape& shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4e2030667ee..1a944d01941 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace xla { namespace { @@ -827,5 +828,19 @@ TEST(AlignmentTest, EXPECT_FALSE(aligned_shape); } +void BM_MakeShape(::testing::benchmark::State& state) { + for (auto s : state) { + ShapeUtil::MakeShape(F32, {2}); + } +} +BENCHMARK(BM_MakeShape); + +void BM_MakeValidatedShape(::testing::benchmark::State& state) { + for (auto s : state) { + ShapeUtil::MakeValidatedShape(F32, {2}).ValueOrDie(); + } +} +BENCHMARK(BM_MakeValidatedShape); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index eade7c2426d..01de56bf85d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -85,6 +85,7 @@ enum PrimitiveType { // Next = 19 } // LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc // )