[XLA:Python] Validate shapes in Python bindings to avoid crashes.

[JAX] Perform LAPACK workspace calculations in int64 to avoid overflows, clamp the values passed to lapack to int32.

Will fix https://github.com/google/jax/issues/4358 when incorporated into a jaxlib.

PiperOrigin-RevId: 337367394
Change-Id: I3b8c116c7bfb764751448ab33ee7ae2a1ebe5ab6
This commit is contained in:
Peter Hawkins 2020-10-15 13:09:37 -07:00 committed by TensorFlower Gardener
parent 6a7d2bcaa9
commit 35f478825b
2 changed files with 40 additions and 10 deletions

View File

@ -206,6 +206,23 @@ bool IsOptimizedBuild() {
#endif // NDEBUG
}
// Safe version of ShapeUtil::MakeShapeWithLayout that fails gracefully on
// invalid input.
StatusOr<Shape> MakeShapeWithLayout(
PrimitiveType element_type, absl::Span<const int64> dims,
absl::optional<absl::Span<const int64>> minor_to_major) {
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeUtil::MakeValidatedShape(element_type, dims));
if (minor_to_major) {
*shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major);
TF_RETURN_IF_ERROR(
LayoutUtil::ValidateLayoutForShape(shape.layout(), shape));
} else {
shape.clear_layout();
}
return shape;
}
} // namespace
PYBIND11_MODULE(xla_extension, m) {
@ -262,15 +279,13 @@ PYBIND11_MODULE(xla_extension, m) {
.def_static(
"array_shape",
[](PrimitiveType type, py::object dims_seq,
absl::optional<py::object> layout_seq) -> Shape {
absl::optional<py::object> layout_seq) -> StatusOr<Shape> {
std::vector<int64> dims = IntSequenceToVector(dims_seq);
if (layout_seq) {
std::vector<int64> layout = IntSequenceToVector(*layout_seq);
return ShapeUtil::MakeShapeWithLayout(type, dims, layout);
return MakeShapeWithLayout(type, dims, layout);
} else {
Shape shape = ShapeUtil::MakeShape(type, dims);
shape.clear_layout();
return shape;
return MakeShapeWithLayout(type, dims, absl::nullopt);
}
},
"Constructs an array shape.", py::arg("type"), py::arg("dims"),
@ -278,16 +293,14 @@ PYBIND11_MODULE(xla_extension, m) {
.def_static(
"array_shape",
[](py::dtype dtype, py::object dims_seq,
absl::optional<py::object> layout_seq) -> Shape {
absl::optional<py::object> layout_seq) -> StatusOr<Shape> {
PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
std::vector<int64> dims = IntSequenceToVector(dims_seq);
if (layout_seq) {
std::vector<int64> layout = IntSequenceToVector(*layout_seq);
return ShapeUtil::MakeShapeWithLayout(type, dims, layout);
return MakeShapeWithLayout(type, dims, layout);
} else {
Shape shape = ShapeUtil::MakeShape(type, dims);
shape.clear_layout();
return shape;
return MakeShapeWithLayout(type, dims, absl::nullopt);
}
},
"Constructs an array shape.", py::arg("type"), py::arg("dims"),

View File

@ -35,6 +35,23 @@ except ImportError:
ops = xla_client.ops
class ShapeTest(absltest.TestCase):
def testInvalidShapes(self):
with self.assertRaisesRegex(RuntimeError,
"shape's dimensions must not be < 0.*"):
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4])
with self.assertRaisesRegex(
RuntimeError, "layout minor_to_major field contains 1 element.*"):
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3])
with self.assertRaisesRegex(
RuntimeError, "layout minor_to_major field has out-of-bounds value.*"):
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4],
[1, -1])
class ComputationPrinting(absltest.TestCase):
def ExampleComputation(self):