Make sure TPUPartitionedInput shape inference doesn't crash if input handle shapes and types are not available.

PiperOrigin-RevId: 341777572
Change-Id: Iad741580d81a51de3d92861f8c999047bd4b163d
This commit is contained in:
Ruoxin Sang 2020-11-11 00:29:04 -08:00 committed by TensorFlower Gardener
parent fdad6c7ae4
commit bec7b3dae4

View File

@ -61,35 +61,40 @@ REGISTER_OP("TPUPartitionedInput")
// If this is a resource, unify the resource shapes.
if (dtype == DT_RESOURCE) {
ShapeHandle previous_shape_handle;
const std::vector<shape_inference::ShapeAndType>* shapes_and_types =
nullptr;
for (int i = c->num_inputs() - 1; i >= 0; --i) {
ShapeHandle shape_handle =
c->input_handle_shapes_and_types(i)->at(0).shape;
if (!c->FullyDefined(shape_handle)) {
return errors::InvalidArgument("Inputs must have static shape,",
"input[", i,
"] has unknown dimension.");
}
if (i != c->num_inputs() - 1) {
ShapeHandle tmp;
if (!c->Merge(shape_handle, previous_shape_handle, &tmp).ok()) {
return errors::InvalidArgument(
"Inputs must have the same shape.");
shapes_and_types = c->input_handle_shapes_and_types(i);
if (shapes_and_types) {
ShapeHandle shape_handle = shapes_and_types->at(0).shape;
if (!c->FullyDefined(shape_handle)) {
return errors::InvalidArgument("Inputs must have static shape,",
"input[", i,
"] has unknown dimension.");
}
if (i != c->num_inputs() - 1) {
ShapeHandle tmp;
if (!c->Merge(shape_handle, previous_shape_handle, &tmp).ok()) {
return errors::InvalidArgument(
"Inputs must have the same shape.");
}
} else {
previous_shape_handle = shape_handle;
}
} else {
previous_shape_handle = shape_handle;
}
}
if (partition_dim == -1) {
c->set_output_handle_shapes_and_types(
0, *c->input_handle_shapes_and_types(0));
} else {
ShapeHandle newoutput0 =
_UpdatePartitionDim(c, previous_shape_handle, partition_dim);
if (shapes_and_types) {
if (partition_dim == -1) {
c->set_output_handle_shapes_and_types(0, *shapes_and_types);
} else {
ShapeHandle newoutput0 =
_UpdatePartitionDim(c, previous_shape_handle, partition_dim);
std::vector<shape_inference::ShapeAndType> output_shapes_and_types;
output_shapes_and_types.push_back(shape_inference::ShapeAndType(
newoutput0, c->input_handle_shapes_and_types(0)->at(0).dtype));
c->set_output_handle_shapes_and_types(0, output_shapes_and_types);
std::vector<shape_inference::ShapeAndType> output_shapes_and_types;
output_shapes_and_types.push_back(shape_inference::ShapeAndType(
newoutput0, shapes_and_types->at(0).dtype));
c->set_output_handle_shapes_and_types(0, output_shapes_and_types);
}
}
}