Speed up Shape creation by avoiding unnecessary validations.

Before: BM_MakeShape  73.1ns +- 0%
After:  BM_MakeShape  11.2ns +- 1%
PiperOrigin-RevId: 343779785
Change-Id: Iedc1ccc784cbbc73ad310a6009e4125601da68da
This commit is contained in:
A. Unique TensorFlower 2020-11-22 20:35:25 -08:00 committed by TensorFlower Gardener
parent e8474b8065
commit b32ec281aa
7 changed files with 110 additions and 12 deletions

View File

@ -361,6 +361,7 @@ tf_cc_test(
":util",
":xla_data_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],

View File

@ -338,8 +338,7 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
/*broadcast_dimensions=*/{0, 1, 2});
auto statusor = BuildHloModule(&b);
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("shape's dimensions must not be < 0"));
EXPECT_THAT(statusor.status().error_message(), HasSubstr("invalid shape"));
}
TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {

View File

@ -38,8 +38,7 @@ ops = xla_client.ops
class ShapeTest(absltest.TestCase):
def testInvalidShapes(self):
with self.assertRaisesRegex(RuntimeError,
"shape's dimensions must not be < 0.*"):
with self.assertRaisesRegex(RuntimeError, "invalid shape"):
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4])
with self.assertRaisesRegex(

View File

@ -52,6 +52,33 @@ namespace xla {
using absl::StrAppend;
using absl::StrCat;
namespace {
// An array that is indexed by PrimitiveType, and returns
// the size of each element of that primitive type, or 0
// if the PrimitiveType is not a primitive type
constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
0, // PRIMITIVE_TYPE_INVALID = 0,
sizeof(int8), // PRED = 1
sizeof(int8), // S8 = 2
sizeof(int16), // S16 = 3
sizeof(int32), // S32 = 4
sizeof(int64), // S64 = 5
sizeof(uint8), // U8 = 6
sizeof(uint16), // U16 = 7
sizeof(uint32), // U32 = 8
sizeof(uint64), // U64 = 9
sizeof(float) / 2, // F16 = 10
sizeof(float), // F32 = 11
sizeof(double), // F64 = 12
0, // TUPLE = 13
0, // OPAQUE_TYPE = 14
sizeof(complex64), // C64 = 15
sizeof(float) / 2, // BF16 = 16
0, // TOKEN = 17
sizeof(complex128) // C128 = 18
};
} // namespace
string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
string ShapeIndexView::ToString() const {
@ -175,6 +202,42 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return accum;
}
/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
absl::Span<const int64> dimensions,
Shape* shape) {
const int eint = static_cast<int>(element_type);
int64 dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
? primitive_byte_size[eint]
: 0); // Out of range: force a failure
if (dense_shape_size <= 0) {
return false;
}
// Verify that array-based lookup is consistent with public API.
DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
<< element_type;
shape->set_element_type(element_type);
const int ndims = dimensions.size();
auto layout = shape->mutable_layout();
layout->set_format(DENSE);
auto* minor_to_major = layout->mutable_minor_to_major();
for (int i = 0; i < ndims; i++) {
const int64 d = dimensions[i];
if (d < 0) {
return false;
}
dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
if (dense_shape_size < 0) {
return false;
}
shape->add_dimensions(d);
minor_to_major->push_back(ndims - 1 - i);
}
return true;
}
/* static */ ProgramShape ShapeUtil::MakeProgramShape(
std::initializer_list<Shape> parameters, Shape result) {
ProgramShape program_shape;
@ -187,7 +250,9 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
absl::Span<const int64> dimensions) {
return MakeValidatedShape(element_type, dimensions).ValueOrDie();
Shape shape;
CHECK(FillNewShape(element_type, dimensions, &shape));
return shape;
}
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
@ -210,18 +275,31 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
PrimitiveType element_type, absl::Span<const int64> dimensions) {
CHECK(IsArrayPrimitiveType(element_type)) << element_type;
Shape result;
TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result));
return result;
Shape shape;
if (!FillNewShape(element_type, dimensions, &shape)) {
return InvalidArgument("invalid shape type=%d, dims=[%s]",
static_cast<int>(element_type),
absl::StrJoin(dimensions, ","));
}
return shape;
}
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
PrimitiveType element_type, absl::Span<const int64> dimensions,
const std::vector<bool>& dynamic_dimensions) {
TF_ASSIGN_OR_RETURN(Shape shape,
MakeValidatedShape(element_type, dimensions));
for (int i = 0; i < dynamic_dimensions.size(); ++i) {
if (dynamic_dimensions.size() != dimensions.size()) {
return InvalidArgument(
"dynamic dimensions size %d did not match number of dimensions %d",
dynamic_dimensions.size(), dimensions.size());
}
Shape shape;
if (!FillNewShape(element_type, dimensions, &shape)) {
return InvalidArgument("invalid shape type=%d, dims=[%s]",
static_cast<int>(element_type),
absl::StrJoin(dimensions, ","));
}
for (int i = 0, n = dimensions.size(); i < n; i++) {
shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
}
return shape;

View File

@ -792,6 +792,11 @@ class ShapeUtil {
static bool CanUpcastIntegral(const Shape& from, const Shape& to);
private:
// Fills *shape. Returns true on success.
// REQUIRES: *shape is empty.
static bool FillNewShape(PrimitiveType element_type,
absl::Span<const int64> dimensions, Shape* shape);
// Validates the shape size is sane. This makes sure it's safe to do
// calculations in int64 without overflowing.
static Status ValidateShapeSize(const Shape& shape);

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace xla {
namespace {
@ -827,5 +828,19 @@ TEST(AlignmentTest,
EXPECT_FALSE(aligned_shape);
}
void BM_MakeShape(::testing::benchmark::State& state) {
for (auto s : state) {
ShapeUtil::MakeShape(F32, {2});
}
}
BENCHMARK(BM_MakeShape);
void BM_MakeValidatedShape(::testing::benchmark::State& state) {
for (auto s : state) {
ShapeUtil::MakeValidatedShape(F32, {2}).ValueOrDie();
}
}
BENCHMARK(BM_MakeValidatedShape);
} // namespace
} // namespace xla

View File

@ -85,6 +85,7 @@ enum PrimitiveType {
// Next = 19
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )