Fix issues with shape inference during graph import.

Change: 144464718
Author: Pete Warden, 2017-01-13 11:27:45 -08:00 (committed by TensorFlower Gardener)
parent 7c96eadae6
commit 8f893368fc
6 changed files with 54 additions and 27 deletions

--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc

@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/node_def.pb_text.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -80,8 +81,7 @@ InferenceContext::InferenceContext(
   PostInputInit(input_handle_shapes, input_handle_dtypes);
 }
 
-InferenceContext::~InferenceContext() {
-}
+InferenceContext::~InferenceContext() {}
 
 Status InferenceContext::set_output(StringPiece output_name,
                                     const std::vector<ShapeHandle>& shapes) {
@@ -231,6 +231,11 @@ string InferenceContext::DebugString(DimensionHandle d) {
   return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
 }
 
+string InferenceContext::DebugString() const {
+  return strings::StrCat("InferenceContext for node: ",
+                         ProtoDebugString(node_def_));
+}
+
 Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
                                   ShapeHandle* out) {
   const int32 existing = Rank(shape);
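A note on the new accessor: the zero-argument DebugString() gives shape functions a way to name the node being inferred when they build error messages. A minimal sketch of the intended usage, with a made-up op and shape function (nothing below is part of this commit):

Status ExampleShapeFn(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle input;
  Status s = c->WithRank(c->input(0), 2, &input);
  if (!s.ok()) {
    // The zero-argument DebugString() identifies the offending node.
    return errors::InvalidArgument("Expected a matrix in ", c->DebugString());
  }
  c->set_output(0, input);
  return Status::OK();
}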

--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h

@@ -259,6 +259,9 @@ class InferenceContext {
   string DebugString(ShapeHandle s);
   string DebugString(DimensionHandle d);
 
+  // Describes the whole context, for debugging purposes.
+  string DebugString() const;
+
   // If <shape> has rank <rank>, or its rank is unknown, return OK and return
   // the shape with asserted rank in <*out>. Otherwise return an error.
   //

--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc

@@ -97,10 +97,9 @@ Status PadShapeFn(InferenceContext* c) {
     return Status::OK();
   }
 
-  // tensor value was provided for paddings_t; doublecheck n_dim value is the
-  // same.
-  const auto num_dims = c->Value(n_dim);
-  DCHECK_EQ(num_dims, paddings_t->shape().dim_size(0));
+  const int64 num_dims = paddings_t->shape().dim_size(0);
+  TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
+  TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
 
   if (paddings_t->dtype() == DT_INT32) {
     return PadKnown<int32>(c, input, paddings_t, num_dims);
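The hunk above replaces a debug-only consistency check with validation that survives graph import: num_dims is now taken from the provided paddings tensor, WithRank constrains the input to that rank, and WithValue rejects a conflicting n_dim with a returned error rather than a DCHECK failure. A simplified sketch of the contract the code relies on (not the real InferenceContext::WithValue implementation):

Status WithValueSketch(shape_inference::InferenceContext* c,
                       shape_inference::DimensionHandle d, int64 value,
                       shape_inference::DimensionHandle* out) {
  if (!c->ValueKnown(d)) {
    *out = c->MakeDim(value);  // unknown dimension: refine it to `value`
    return Status::OK();
  }
  if (c->Value(d) == value) {
    *out = d;  // already consistent: nothing to do
    return Status::OK();
  }
  // Known but different: report an error instead of crashing a debug build.
  return errors::InvalidArgument("Dimension must be ", value, " but is ",
                                 c->Value(d));
}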
@@ -440,7 +439,8 @@ REGISTER_OP("SplitV")
           c->set_output(i, output_shape);
         }
       } else {
-        // Determine the output shape if split dimension and split sizes are known
+        // Determine the output shape if split dimension and split sizes are
+        // known.
         int64 split_dim = c->Value(split_dimension);
         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
         std::vector<int64> data;
@@ -455,8 +455,8 @@ REGISTER_OP("SplitV")
         }
         for (int i = 0; i < num_outputs; ++i) {
           output_shape = c->UnknownShapeOfRank(rank);
-          TF_RETURN_IF_ERROR(
-              c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape));
+          TF_RETURN_IF_ERROR(c->ReplaceDim(input, split_dim,
+                                           c->MakeDim(data[i]), &output_shape));
           c->set_output(i, output_shape);
         }
       }
@@ -1370,8 +1370,8 @@ REGISTER_OP("GatherNd")
       if (c->Value(r_dim) > c->Rank(params)) {
         return errors::InvalidArgument(
             "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
-            c->DebugString(indices), " and params shape: ",
-            c->DebugString(params));
+            c->DebugString(indices),
+            " and params shape: ", c->DebugString(params));
       }
 
       // Remove r_dim from indices to get output.
@@ -1906,12 +1906,12 @@ REGISTER_OP("ReverseSequence")
       // Validate batch_dim and seq_dim against input.
       const int32 input_rank = c->Rank(input);
       if (batch_dim >= input_rank) {
-        return errors::InvalidArgument("batch_dim must be < input rank: ",
-                                       batch_dim, " vs. ", input_rank);
+        return errors::InvalidArgument(
+            "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
       }
       if (seq_dim >= input_rank) {
-        return errors::InvalidArgument("seq_dim must be < input rank: ",
-                                       seq_dim, " vs. ", input_rank);
+        return errors::InvalidArgument(
+            "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
       }
 
       DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
@@ -3790,8 +3790,9 @@ REGISTER_OP("SpaceToDepth")
       TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, 3), block_size * block_size,
                                      &output_depth));
 
-      c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
-                                     output_width, output_depth}));
+      c->set_output(0,
+                    c->MakeShape({c->Dim(input, 0), output_height, output_width,
+                                  output_depth}));
       return Status::OK();
     })
     .Doc(R"doc(
@@ -3895,8 +3896,9 @@ REGISTER_OP("DepthToSpace")
      TF_RETURN_IF_ERROR(c->Divide(c->Dim(input, 3), block_size * block_size,
                                   true /* evenly_divisible */, &output_depth));
 
-      c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
-                                     output_width, output_depth}));
+      c->set_output(0,
+                    c->MakeShape({c->Dim(input, 0), output_height, output_width,
+                                  output_depth}));
       return Status::OK();
     })
     .Doc(R"doc(
@@ -4772,8 +4774,9 @@ Status ScatterNdShape(InferenceContext* c) {
     Status s = c->Merge(prefix_indices, prefix_updates, &unused);
     if (!s.ok()) {
       return errors::InvalidArgument(
-          "The outer ", outer_dims, " dimensions of indices.shape=",
-          c->DebugString(indices_shape), " must match the outer ", outer_dims,
+          "The outer ", outer_dims,
+          " dimensions of indices.shape=", c->DebugString(indices_shape),
+          " must match the outer ", outer_dims,
           " dimensions of updates.shape=", c->DebugString(updates_shape),
           ": ", s.error_message());
     }

--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc

@@ -331,6 +331,7 @@ TEST(ArrayOpsTest, PadD_ShapeFn) {
     INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
     INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
     INFER_OK(op, "?;[3,2]", "[?,?,?]");
+    INFER_OK(op, "?;?", "[?,?,?]");
   }
 }
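The added "?;?" case passes only because the test harness supplies the paddings as a constant tensor: with the fix, PadShapeFn reads the rank from that tensor's first dimension even though the paddings shape input itself is unknown. Roughly how such a constant is wired in (names and values illustrative; the real test's setup may differ):

ShapeInferenceTestOp op("Pad");
Tensor paddings(DT_INT32, TensorShape({3, 2}));
test::FillValues<int32>(&paddings, {1, 10, 2, 20, 3, 30});
op.input_tensors = {nullptr, &paddings};  // paddings is input #1 of Pad
// "?;?" then infers "[?,?,?]": rank 3 comes from paddings' first dimension.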

--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc

@@ -43,6 +43,14 @@ Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
     width = c->UnknownDim();
     height = c->UnknownDim();
   } else {
+    // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
+    if (size_tensor->dtype() != DT_INT32) {
+      return errors::InvalidArgument(
+          "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
+          "but got ",
+          DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
+          " in ", c->DebugString());
+    }
     auto vec = size_tensor->vec<int32>();
     height = c->MakeDim(vec(0));
     width = c->MakeDim(vec(1));
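Some context for the guard above: Tensor::vec<int32>() CHECK-fails when the tensor's dtype is not DT_INT32, so before this change a graph carrying, say, an int64 size tensor could crash the process during import. The guard turns that into an InvalidArgument error that names the node via the new DebugString(). A minimal sketch of the safe access pattern (illustrative only, not part of the commit):

Status ReadSizeVector(const Tensor& size_tensor, int32* height, int32* width) {
  if (size_tensor.dtype() != DT_INT32) {
    // vec<int32>() below would CHECK-fail on a mismatched dtype.
    return errors::InvalidArgument("Expected DT_INT32 size tensor, got ",
                                   DataTypeString(size_tensor.dtype()));
  }
  auto vec = size_tensor.vec<int32>();
  *height = vec(0);
  *width = vec(1);
  return Status::OK();
}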
@@ -74,7 +82,8 @@ Status DecodeImageShapeFn(InferenceContext* c) {
     channels_dim = c->MakeDim(channels);
   }
 
-  c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
-                                 InferenceContext::kUnknownDim, channels_dim}));
+  c->set_output(0,
+                c->MakeShape({InferenceContext::kUnknownDim,
+                              InferenceContext::kUnknownDim, channels_dim}));
   return Status::OK();
 }
@@ -555,7 +564,8 @@ REGISTER_OP("DecodeGif")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
-      c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
-                                     InferenceContext::kUnknownDim,
-                                     InferenceContext::kUnknownDim, 3}));
+      c->set_output(0,
+                    c->MakeShape({InferenceContext::kUnknownDim,
+                                  InferenceContext::kUnknownDim,
+                                  InferenceContext::kUnknownDim, 3}));
       return Status::OK();

--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py

@@ -638,6 +638,11 @@ def _ConstantValue(tensor):
     return np.concatenate(values, axis=dim)
   elif tensor.op.type == "Pack":
     values = []
+    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
+    # and shouldn't be produced, but to deal sensibly with them here we check
+    # and return None.
+    if not tensor.op.inputs:
+      return None
     for x in tensor.op.inputs:
       value = constant_value(x)
       if value is None:
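For reference, the kind of malformed node this guards against looks roughly like the GraphDef fragment below: a Pack op with N = 0 and no inputs, which has no sensible packed value, so constant_value now answers None for it. (Hand-written illustration, not taken from a real graph.)

node {
  name: "empty_pack"  # illustrative name
  op: "Pack"
  attr { key: "N" value { i: 0 } }
  attr { key: "T" value { type: DT_INT32 } }
  attr { key: "axis" value { i: 0 } }
}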