Make sure TPUPartitionedInput shape inference doesn't crash if input handle shapes and types are not available.
PiperOrigin-RevId: 341777572 Change-Id: Iad741580d81a51de3d92861f8c999047bd4b163d
This commit is contained in:
parent
fdad6c7ae4
commit
bec7b3dae4
@ -61,35 +61,40 @@ REGISTER_OP("TPUPartitionedInput")
|
||||
// If this is a resource, unify the resource shapes.
|
||||
if (dtype == DT_RESOURCE) {
|
||||
ShapeHandle previous_shape_handle;
|
||||
const std::vector<shape_inference::ShapeAndType>* shapes_and_types =
|
||||
nullptr;
|
||||
for (int i = c->num_inputs() - 1; i >= 0; --i) {
|
||||
ShapeHandle shape_handle =
|
||||
c->input_handle_shapes_and_types(i)->at(0).shape;
|
||||
if (!c->FullyDefined(shape_handle)) {
|
||||
return errors::InvalidArgument("Inputs must have static shape,",
|
||||
"input[", i,
|
||||
"] has unknown dimension.");
|
||||
}
|
||||
if (i != c->num_inputs() - 1) {
|
||||
ShapeHandle tmp;
|
||||
if (!c->Merge(shape_handle, previous_shape_handle, &tmp).ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"Inputs must have the same shape.");
|
||||
shapes_and_types = c->input_handle_shapes_and_types(i);
|
||||
if (shapes_and_types) {
|
||||
ShapeHandle shape_handle = shapes_and_types->at(0).shape;
|
||||
if (!c->FullyDefined(shape_handle)) {
|
||||
return errors::InvalidArgument("Inputs must have static shape,",
|
||||
"input[", i,
|
||||
"] has unknown dimension.");
|
||||
}
|
||||
if (i != c->num_inputs() - 1) {
|
||||
ShapeHandle tmp;
|
||||
if (!c->Merge(shape_handle, previous_shape_handle, &tmp).ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"Inputs must have the same shape.");
|
||||
}
|
||||
} else {
|
||||
previous_shape_handle = shape_handle;
|
||||
}
|
||||
} else {
|
||||
previous_shape_handle = shape_handle;
|
||||
}
|
||||
}
|
||||
if (partition_dim == -1) {
|
||||
c->set_output_handle_shapes_and_types(
|
||||
0, *c->input_handle_shapes_and_types(0));
|
||||
} else {
|
||||
ShapeHandle newoutput0 =
|
||||
_UpdatePartitionDim(c, previous_shape_handle, partition_dim);
|
||||
if (shapes_and_types) {
|
||||
if (partition_dim == -1) {
|
||||
c->set_output_handle_shapes_and_types(0, *shapes_and_types);
|
||||
} else {
|
||||
ShapeHandle newoutput0 =
|
||||
_UpdatePartitionDim(c, previous_shape_handle, partition_dim);
|
||||
|
||||
std::vector<shape_inference::ShapeAndType> output_shapes_and_types;
|
||||
output_shapes_and_types.push_back(shape_inference::ShapeAndType(
|
||||
newoutput0, c->input_handle_shapes_and_types(0)->at(0).dtype));
|
||||
c->set_output_handle_shapes_and_types(0, output_shapes_and_types);
|
||||
std::vector<shape_inference::ShapeAndType> output_shapes_and_types;
|
||||
output_shapes_and_types.push_back(shape_inference::ShapeAndType(
|
||||
newoutput0, shapes_and_types->at(0).dtype));
|
||||
c->set_output_handle_shapes_and_types(0, output_shapes_and_types);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user