Speed up Shape creation by avoiding unnecessary validations.
Before: BM_MakeShape 73.1ns +- 0% After: BM_MakeShape 11.2ns +- 1% PiperOrigin-RevId: 343779785 Change-Id: Iedc1ccc784cbbc73ad310a6009e4125601da68da
This commit is contained in:
parent
e8474b8065
commit
b32ec281aa
@ -361,6 +361,7 @@ tf_cc_test(
|
||||
":util",
|
||||
":xla_data_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -338,8 +338,7 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
|
||||
/*broadcast_dimensions=*/{0, 1, 2});
|
||||
auto statusor = BuildHloModule(&b);
|
||||
ASSERT_FALSE(statusor.ok());
|
||||
EXPECT_THAT(statusor.status().error_message(),
|
||||
HasSubstr("shape's dimensions must not be < 0"));
|
||||
EXPECT_THAT(statusor.status().error_message(), HasSubstr("invalid shape"));
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
|
||||
|
@ -38,8 +38,7 @@ ops = xla_client.ops
|
||||
class ShapeTest(absltest.TestCase):
|
||||
|
||||
def testInvalidShapes(self):
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"shape's dimensions must not be < 0.*"):
|
||||
with self.assertRaisesRegex(RuntimeError, "invalid shape"):
|
||||
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
|
@ -52,6 +52,33 @@ namespace xla {
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
|
||||
namespace {
|
||||
// An array that is indexed by PrimitiveType, and returns
|
||||
// the size of each element of that primitive type, or 0
|
||||
// if the PrimitiveType is not a primitive type
|
||||
constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
|
||||
0, // PRIMITIVE_TYPE_INVALID = 0,
|
||||
sizeof(int8), // PRED = 1
|
||||
sizeof(int8), // S8 = 2
|
||||
sizeof(int16), // S16 = 3
|
||||
sizeof(int32), // S32 = 4
|
||||
sizeof(int64), // S64 = 5
|
||||
sizeof(uint8), // U8 = 6
|
||||
sizeof(uint16), // U16 = 7
|
||||
sizeof(uint32), // U32 = 8
|
||||
sizeof(uint64), // U64 = 9
|
||||
sizeof(float) / 2, // F16 = 10
|
||||
sizeof(float), // F32 = 11
|
||||
sizeof(double), // F64 = 12
|
||||
0, // TUPLE = 13
|
||||
0, // OPAQUE_TYPE = 14
|
||||
sizeof(complex64), // C64 = 15
|
||||
sizeof(float) / 2, // BF16 = 16
|
||||
0, // TOKEN = 17
|
||||
sizeof(complex128) // C128 = 18
|
||||
};
|
||||
} // namespace
|
||||
|
||||
string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
|
||||
|
||||
string ShapeIndexView::ToString() const {
|
||||
@ -175,6 +202,42 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
return accum;
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
|
||||
absl::Span<const int64> dimensions,
|
||||
Shape* shape) {
|
||||
const int eint = static_cast<int>(element_type);
|
||||
int64 dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
|
||||
? primitive_byte_size[eint]
|
||||
: 0); // Out of range: force a failure
|
||||
if (dense_shape_size <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify that array-based lookup is consistent with public API.
|
||||
DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
|
||||
<< element_type;
|
||||
|
||||
shape->set_element_type(element_type);
|
||||
const int ndims = dimensions.size();
|
||||
auto layout = shape->mutable_layout();
|
||||
layout->set_format(DENSE);
|
||||
auto* minor_to_major = layout->mutable_minor_to_major();
|
||||
for (int i = 0; i < ndims; i++) {
|
||||
const int64 d = dimensions[i];
|
||||
if (d < 0) {
|
||||
return false;
|
||||
}
|
||||
dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
|
||||
if (dense_shape_size < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
shape->add_dimensions(d);
|
||||
minor_to_major->push_back(ndims - 1 - i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/* static */ ProgramShape ShapeUtil::MakeProgramShape(
|
||||
std::initializer_list<Shape> parameters, Shape result) {
|
||||
ProgramShape program_shape;
|
||||
@ -187,7 +250,9 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
|
||||
absl::Span<const int64> dimensions) {
|
||||
return MakeValidatedShape(element_type, dimensions).ValueOrDie();
|
||||
Shape shape;
|
||||
CHECK(FillNewShape(element_type, dimensions, &shape));
|
||||
return shape;
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
|
||||
@ -210,18 +275,31 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
|
||||
PrimitiveType element_type, absl::Span<const int64> dimensions) {
|
||||
CHECK(IsArrayPrimitiveType(element_type)) << element_type;
|
||||
Shape result;
|
||||
TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result));
|
||||
return result;
|
||||
Shape shape;
|
||||
if (!FillNewShape(element_type, dimensions, &shape)) {
|
||||
return InvalidArgument("invalid shape type=%d, dims=[%s]",
|
||||
static_cast<int>(element_type),
|
||||
absl::StrJoin(dimensions, ","));
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
|
||||
PrimitiveType element_type, absl::Span<const int64> dimensions,
|
||||
const std::vector<bool>& dynamic_dimensions) {
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
MakeValidatedShape(element_type, dimensions));
|
||||
for (int i = 0; i < dynamic_dimensions.size(); ++i) {
|
||||
if (dynamic_dimensions.size() != dimensions.size()) {
|
||||
return InvalidArgument(
|
||||
"dynamic dimensions size %d did not match number of dimensions %d",
|
||||
dynamic_dimensions.size(), dimensions.size());
|
||||
}
|
||||
|
||||
Shape shape;
|
||||
if (!FillNewShape(element_type, dimensions, &shape)) {
|
||||
return InvalidArgument("invalid shape type=%d, dims=[%s]",
|
||||
static_cast<int>(element_type),
|
||||
absl::StrJoin(dimensions, ","));
|
||||
}
|
||||
for (int i = 0, n = dimensions.size(); i < n; i++) {
|
||||
shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
|
||||
}
|
||||
return shape;
|
||||
|
@ -792,6 +792,11 @@ class ShapeUtil {
|
||||
static bool CanUpcastIntegral(const Shape& from, const Shape& to);
|
||||
|
||||
private:
|
||||
// Fills *shape. Returns true on success.
|
||||
// REQUIRES: *shape is empty.
|
||||
static bool FillNewShape(PrimitiveType element_type,
|
||||
absl::Span<const int64> dimensions, Shape* shape);
|
||||
|
||||
// Validates the shape size is sane. This makes sure it's safe to do
|
||||
// calculations in int64 without overflowing.
|
||||
static Status ValidateShapeSize(const Shape& shape);
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
@ -827,5 +828,19 @@ TEST(AlignmentTest,
|
||||
EXPECT_FALSE(aligned_shape);
|
||||
}
|
||||
|
||||
void BM_MakeShape(::testing::benchmark::State& state) {
|
||||
for (auto s : state) {
|
||||
ShapeUtil::MakeShape(F32, {2});
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_MakeShape);
|
||||
|
||||
void BM_MakeValidatedShape(::testing::benchmark::State& state) {
|
||||
for (auto s : state) {
|
||||
ShapeUtil::MakeValidatedShape(F32, {2}).ValueOrDie();
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_MakeValidatedShape);
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -85,6 +85,7 @@ enum PrimitiveType {
|
||||
// Next = 19
|
||||
}
|
||||
// LINT.ThenChange(
|
||||
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
|
||||
// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
|
||||
// )
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user