Fix issues with shape inference during graph import.

Change: 144464718
Pete Warden 2017-01-13 11:27:45 -08:00 committed by TensorFlower Gardener
parent 7c96eadae6
commit 8f893368fc
6 changed files with 54 additions and 27 deletions
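The common thread in these fixes: a GraphDef that arrives via import can contain constructs the Python graph builders never produce (a Pack op with zero inputs, a resize size tensor with the wrong dtype, shapes of unknown rank), so shape functions must validate their assumptions and return Status errors rather than DCHECK-ing invariants that only hold for freshly built graphs. A minimal sketch of that defensive style (illustrative only, with hypothetical names; not code from this commit):

    #include "tensorflow/core/framework/shape_inference.h"
    #include "tensorflow/core/framework/types.h"
    #include "tensorflow/core/lib/core/errors.h"

    namespace tensorflow {

    // Hypothetical shape function: every assumption about the incoming
    // graph is checked and surfaced as a returned error, never as a
    // DCHECK that compiles away in release builds.
    Status ExampleShapeFn(shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle input;
      // Fails cleanly if input 0 has a known rank other than 2.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));

      const Tensor* t = c->input_tensor(1);
      if (t == nullptr) {
        // Constant value unavailable at graph-construction time.
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }
      if (t->dtype() != DT_INT32) {
        // Imported graphs may break dtype guarantees the builders keep.
        return errors::InvalidArgument("Expected DT_INT32 but got ",
                                       DataTypeString(t->dtype()));
      }
      c->set_output(0, c->MakeShape({c->MakeDim(t->vec<int32>()(0))}));
      return Status::OK();
    }

    }  // namespace tensorflow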

tensorflow/core/framework/shape_inference.cc

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -80,8 +81,7 @@ InferenceContext::InferenceContext(
PostInputInit(input_handle_shapes, input_handle_dtypes);
}
-InferenceContext::~InferenceContext() {
-}
+InferenceContext::~InferenceContext() {}
Status InferenceContext::set_output(StringPiece output_name,
const std::vector<ShapeHandle>& shapes) {
@@ -231,6 +231,11 @@ string InferenceContext::DebugString(DimensionHandle d) {
return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
}
+string InferenceContext::DebugString() const {
+  return strings::StrCat("InferenceContext for node: ",
+                         ProtoDebugString(node_def_));
+}
Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
ShapeHandle* out) {
const int32 existing = Rank(shape);

View File

@@ -259,6 +259,9 @@ class InferenceContext {
string DebugString(ShapeHandle s);
string DebugString(DimensionHandle d);
+  // Describes the whole context, for debugging purposes.
+  string DebugString() const;
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
// the shape with asserted rank in <*out>. Otherwise return an error.
//
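The new context-wide overload added above complements the per-shape DebugString(ShapeHandle) and DebugString(DimensionHandle): it identifies which node inference was running on, which is what the image_ops change below uses to make its error message actionable. A hypothetical call site (sketch, not from this commit):

    // Inside a shape function: name the node, not just the bad value.
    // The result reads like "InferenceContext for node: <NodeDef as text>".
    return errors::InvalidArgument("unexpected input for ", c->DebugString());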

tensorflow/core/ops/array_ops.cc

@@ -97,10 +97,9 @@ Status PadShapeFn(InferenceContext* c) {
return Status::OK();
}
-  // tensor value was provided for paddings_t; doublecheck n_dim value is the
-  // same.
-  const auto num_dims = c->Value(n_dim);
-  DCHECK_EQ(num_dims, paddings_t->shape().dim_size(0));
+  const int64 num_dims = paddings_t->shape().dim_size(0);
+  TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
+  TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
if (paddings_t->dtype() == DT_INT32) {
return PadKnown<int32>(c, input, paddings_t, num_dims);
@@ -440,7 +439,8 @@ REGISTER_OP("SplitV")
c->set_output(i, output_shape);
}
} else {
-        // Determine the output shape if split dimension and split sizes are known
+        // Determine the output shape if split dimension and split sizes are
+        // known.
int64 split_dim = c->Value(split_dimension);
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
std::vector<int64> data;
@@ -451,12 +451,12 @@ REGISTER_OP("SplitV")
}
if (num_outputs != data.size()) {
return errors::InvalidArgument(
"Length of size_splits should be equal to num_outputs");
"Length of size_splits should be equal to num_outputs");
}
for (int i = 0; i < num_outputs; ++i) {
output_shape = c->UnknownShapeOfRank(rank);
-          TF_RETURN_IF_ERROR(
-              c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape));
+          TF_RETURN_IF_ERROR(c->ReplaceDim(input, split_dim,
+                                           c->MakeDim(data[i]), &output_shape));
c->set_output(i, output_shape);
}
}
@@ -1370,8 +1370,8 @@ REGISTER_OP("GatherNd")
if (c->Value(r_dim) > c->Rank(params)) {
return errors::InvalidArgument(
"indices.shape[-1] must be <= params.rank, but saw indices shape: ",
c->DebugString(indices), " and params shape: ",
c->DebugString(params));
c->DebugString(indices),
" and params shape: ", c->DebugString(params));
}
// Remove r_dim from indices to get output.
@@ -1906,12 +1906,12 @@ REGISTER_OP("ReverseSequence")
// Validate batch_dim and seq_dim against input.
const int32 input_rank = c->Rank(input);
if (batch_dim >= input_rank) {
return errors::InvalidArgument("batch_dim must be < input rank: ",
batch_dim, " vs. ", input_rank);
return errors::InvalidArgument(
"batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
}
if (seq_dim >= input_rank) {
return errors::InvalidArgument("seq_dim must be < input rank: ",
seq_dim, " vs. ", input_rank);
return errors::InvalidArgument(
"seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
}
DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
@@ -3790,8 +3790,9 @@ REGISTER_OP("SpaceToDepth")
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, 3), block_size * block_size,
&output_depth));
-      c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
-                                     output_width, output_depth}));
+      c->set_output(0,
+                    c->MakeShape({c->Dim(input, 0), output_height, output_width,
+                                  output_depth}));
return Status::OK();
})
.Doc(R"doc(
@@ -3895,8 +3896,9 @@ REGISTER_OP("DepthToSpace")
TF_RETURN_IF_ERROR(c->Divide(c->Dim(input, 3), block_size * block_size,
true /* evenly_divisible */, &output_depth));
-      c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
-                                     output_width, output_depth}));
+      c->set_output(0,
+                    c->MakeShape({c->Dim(input, 0), output_height, output_width,
+                                  output_depth}));
return Status::OK();
})
.Doc(R"doc(
@@ -4772,8 +4774,9 @@ Status ScatterNdShape(InferenceContext* c) {
Status s = c->Merge(prefix_indices, prefix_updates, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"The outer ", outer_dims, " dimensions of indices.shape=",
c->DebugString(indices_shape), " must match the outer ", outer_dims,
"The outer ", outer_dims,
" dimensions of indices.shape=", c->DebugString(indices_shape),
" must match the outer ", outer_dims,
" dimensions of updates.shape=", c->DebugString(updates_shape),
": ", s.error_message());
}
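The substantive change in this file is in PadShapeFn (the first hunk): the old code asked c->Value(n_dim) for the rank even though n_dim can be an unknown dimension in an imported graph (Value() then yields kUnknownDim, i.e. -1) and merely DCHECK-ed it against the paddings tensor, so debug builds aborted and release builds propagated a bogus value. The new code takes the rank from the paddings tensor actually in hand and asserts it back onto the other shapes, turning any mismatch into a returned error. The pattern, with the commit's own names, annotated:

    // paddings_t has shape [n, 2]; its first dimension is ground truth
    // for the rank regardless of what the static shapes claim.
    const int64 num_dims = paddings_t->shape().dim_size(0);
    // Both constraints fail with InvalidArgument on conflict instead of
    // DCHECK-crashing, and refine the handles when they succeed:
    TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
    TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));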

tensorflow/core/ops/array_ops_test.cc

@@ -331,6 +331,7 @@ TEST(ArrayOpsTest, PadD_ShapeFn) {
INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
INFER_OK(op, "?;[3,2]", "[?,?,?]");
INFER_OK(op, "?;?", "[?,?,?]");
}
}
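The added "?;?" case exercises exactly the PadShapeFn fix. Assuming the fixture supplies the paddings as a constant [3,2] input tensor, as the sibling cases suggest, the shape function can now recover the rank from the tensor even when both declared shapes are unknown; previously this combination hit the DCHECK. A reading of the new assertion (the fixture details here are assumptions):

    // input shape:    ?                      (rank unknown)
    // paddings shape: ?                      (rank unknown)
    // paddings value: [[a,b],[c,d],[e,f]]    -> rank must be 3
    // inferred:       [?,?,?]                (rank known, sizes not)
    INFER_OK(op, "?;?", "[?,?,?]");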

tensorflow/core/ops/image_ops.cc

@@ -43,6 +43,14 @@ Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
width = c->UnknownDim();
height = c->UnknownDim();
} else {
+    // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
+    if (size_tensor->dtype() != DT_INT32) {
+      return errors::InvalidArgument(
+          "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
+          "but got ",
+          DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
+          " in ", c->DebugString());
+    }
auto vec = size_tensor->vec<int32>();
height = c->MakeDim(vec(0));
width = c->MakeDim(vec(1));
@@ -74,8 +82,9 @@ Status DecodeImageShapeFn(InferenceContext* c) {
channels_dim = c->MakeDim(channels);
}
-  c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
-                                 InferenceContext::kUnknownDim, channels_dim}));
+  c->set_output(0,
+                c->MakeShape({InferenceContext::kUnknownDim,
+                              InferenceContext::kUnknownDim, channels_dim}));
return Status::OK();
}
@@ -555,9 +564,10 @@ REGISTER_OP("DecodeGif")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
-      c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
-                                     InferenceContext::kUnknownDim,
-                                     InferenceContext::kUnknownDim, 3}));
+      c->set_output(0,
+                    c->MakeShape({InferenceContext::kUnknownDim,
+                                  InferenceContext::kUnknownDim,
+                                  InferenceContext::kUnknownDim, 3}));
return Status::OK();
})
.Doc(R"doc(

tensorflow/python/framework/tensor_util.py

@@ -638,6 +638,11 @@ def _ConstantValue(tensor):
return np.concatenate(values, axis=dim)
elif tensor.op.type == "Pack":
values = []
+    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
+    # and shouldn't be produced, but to deal sensibly with them here we check
+    # and return None.
+    if not tensor.op.inputs:
+      return None
for x in tensor.op.inputs:
value = constant_value(x)
if value is None: