[XLA:CPU] Assert that parameters have the correct shape/size.

If we have an HloModule, check that they're of the correct shape.  If we don't
have an HloModule, at least we can check that they have the correct size.

PiperOrigin-RevId: 254415709
This commit is contained in:
Justin Lebar 2019-06-21 09:57:03 -07:00 committed by TensorFlower Gardener
parent 6609962713
commit 8f4e309f2d

View File

@ -92,6 +92,9 @@ CpuExecutable::CreateBufferTable(
if (allocation.is_entry_computation_parameter()) {
unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
allocation.param_shape_index());
CHECK_EQ(allocation.size(), unowning_buffers[i].size())
<< "Size mismatch on param " << allocation.parameter_number()
<< " at shape index " << allocation.param_shape_index().ToString();
VLOG(3) << "allocation #" << i << " is a parameter";
continue;
}
@ -294,6 +297,21 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
return Unimplemented("Points-to set of root instruction is ambiguous");
}
if (hlo_module_) {
const HloComputation* entry_comp = hlo_module_->entry_computation();
CHECK_EQ(entry_comp->num_parameters(), arguments.size())
<< "Wrong number of arguments passed when running executable";
for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
const Shape& expected_shape =
entry_comp->parameter_instruction(i)->shape();
const Shape& actual_shape = arguments[i]->on_device_shape();
CHECK(expected_shape == actual_shape) << absl::StreamFormat(
"Shape mismatch on argument %d. Expected %s, but was %s.", i,
expected_shape.ToString(/*print_layout=*/true),
actual_shape.ToString(/*print_layout=*/true));
}
}
auto* host_stream = dynamic_cast<se::host::HostStream*>(
run_options->stream()->implementation());
se::Stream* stream = run_options->stream();