[XLA:CPU] Assert that parameters have the correct shape/size.
If we have an HloModule, check that they're of the correct shape. If we don't have an HloModule, at least we can check that they have the correct size. PiperOrigin-RevId: 254415709
This commit is contained in:
parent
6609962713
commit
8f4e309f2d
@ -92,6 +92,9 @@ CpuExecutable::CreateBufferTable(
|
||||
if (allocation.is_entry_computation_parameter()) {
|
||||
unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
|
||||
allocation.param_shape_index());
|
||||
CHECK_EQ(allocation.size(), unowning_buffers[i].size())
|
||||
<< "Size mismatch on param " << allocation.parameter_number()
|
||||
<< " at shape index " << allocation.param_shape_index().ToString();
|
||||
VLOG(3) << "allocation #" << i << " is a parameter";
|
||||
continue;
|
||||
}
|
||||
@ -294,6 +297,21 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
|
||||
return Unimplemented("Points-to set of root instruction is ambiguous");
|
||||
}
|
||||
|
||||
if (hlo_module_) {
|
||||
const HloComputation* entry_comp = hlo_module_->entry_computation();
|
||||
CHECK_EQ(entry_comp->num_parameters(), arguments.size())
|
||||
<< "Wrong number of arguments passed when running executable";
|
||||
for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
|
||||
const Shape& expected_shape =
|
||||
entry_comp->parameter_instruction(i)->shape();
|
||||
const Shape& actual_shape = arguments[i]->on_device_shape();
|
||||
CHECK(expected_shape == actual_shape) << absl::StreamFormat(
|
||||
"Shape mismatch on argument %d. Expected %s, but was %s.", i,
|
||||
expected_shape.ToString(/*print_layout=*/true),
|
||||
actual_shape.ToString(/*print_layout=*/true));
|
||||
}
|
||||
}
|
||||
|
||||
auto* host_stream = dynamic_cast<se::host::HostStream*>(
|
||||
run_options->stream()->implementation());
|
||||
se::Stream* stream = run_options->stream();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user