[XLA:Python] Validate shapes in Python bindings to avoid crashes.
[JAX] Perform LAPACK workspace calculations in int64 to avoid overflows, clamp the values passed to lapack to int32. Will fix https://github.com/google/jax/issues/4358 when incorporated into a jaxlib. PiperOrigin-RevId: 337367394 Change-Id: I3b8c116c7bfb764751448ab33ee7ae2a1ebe5ab6
This commit is contained in:
parent
6a7d2bcaa9
commit
35f478825b
@ -206,6 +206,23 @@ bool IsOptimizedBuild() {
|
||||
#endif // NDEBUG
|
||||
}
|
||||
|
||||
// Safe version of ShapeUtil::MakeShapeWithLayout that fails gracefully on
|
||||
// invalid input.
|
||||
StatusOr<Shape> MakeShapeWithLayout(
|
||||
PrimitiveType element_type, absl::Span<const int64> dims,
|
||||
absl::optional<absl::Span<const int64>> minor_to_major) {
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeUtil::MakeValidatedShape(element_type, dims));
|
||||
if (minor_to_major) {
|
||||
*shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major);
|
||||
TF_RETURN_IF_ERROR(
|
||||
LayoutUtil::ValidateLayoutForShape(shape.layout(), shape));
|
||||
} else {
|
||||
shape.clear_layout();
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PYBIND11_MODULE(xla_extension, m) {
|
||||
@ -262,15 +279,13 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def_static(
|
||||
"array_shape",
|
||||
[](PrimitiveType type, py::object dims_seq,
|
||||
absl::optional<py::object> layout_seq) -> Shape {
|
||||
absl::optional<py::object> layout_seq) -> StatusOr<Shape> {
|
||||
std::vector<int64> dims = IntSequenceToVector(dims_seq);
|
||||
if (layout_seq) {
|
||||
std::vector<int64> layout = IntSequenceToVector(*layout_seq);
|
||||
return ShapeUtil::MakeShapeWithLayout(type, dims, layout);
|
||||
return MakeShapeWithLayout(type, dims, layout);
|
||||
} else {
|
||||
Shape shape = ShapeUtil::MakeShape(type, dims);
|
||||
shape.clear_layout();
|
||||
return shape;
|
||||
return MakeShapeWithLayout(type, dims, absl::nullopt);
|
||||
}
|
||||
},
|
||||
"Constructs an array shape.", py::arg("type"), py::arg("dims"),
|
||||
@ -278,16 +293,14 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def_static(
|
||||
"array_shape",
|
||||
[](py::dtype dtype, py::object dims_seq,
|
||||
absl::optional<py::object> layout_seq) -> Shape {
|
||||
absl::optional<py::object> layout_seq) -> StatusOr<Shape> {
|
||||
PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
|
||||
std::vector<int64> dims = IntSequenceToVector(dims_seq);
|
||||
if (layout_seq) {
|
||||
std::vector<int64> layout = IntSequenceToVector(*layout_seq);
|
||||
return ShapeUtil::MakeShapeWithLayout(type, dims, layout);
|
||||
return MakeShapeWithLayout(type, dims, layout);
|
||||
} else {
|
||||
Shape shape = ShapeUtil::MakeShape(type, dims);
|
||||
shape.clear_layout();
|
||||
return shape;
|
||||
return MakeShapeWithLayout(type, dims, absl::nullopt);
|
||||
}
|
||||
},
|
||||
"Constructs an array shape.", py::arg("type"), py::arg("dims"),
|
||||
|
||||
@ -35,6 +35,23 @@ except ImportError:
|
||||
ops = xla_client.ops
|
||||
|
||||
|
||||
class ShapeTest(absltest.TestCase):
|
||||
|
||||
def testInvalidShapes(self):
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"shape's dimensions must not be < 0.*"):
|
||||
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "layout minor_to_major field contains 1 element.*"):
|
||||
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "layout minor_to_major field has out-of-bounds value.*"):
|
||||
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4],
|
||||
[1, -1])
|
||||
|
||||
|
||||
class ComputationPrinting(absltest.TestCase):
|
||||
|
||||
def ExampleComputation(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user