Fix issues with shape inference during graph import.
Change: 144464718
This commit is contained in:
parent
7c96eadae6
commit
8f893368fc
tensorflow
core
python/framework
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -80,8 +81,7 @@ InferenceContext::InferenceContext(
|
||||
PostInputInit(input_handle_shapes, input_handle_dtypes);
|
||||
}
|
||||
|
||||
InferenceContext::~InferenceContext() {
|
||||
}
|
||||
InferenceContext::~InferenceContext() {}
|
||||
|
||||
Status InferenceContext::set_output(StringPiece output_name,
|
||||
const std::vector<ShapeHandle>& shapes) {
|
||||
@ -231,6 +231,11 @@ string InferenceContext::DebugString(DimensionHandle d) {
|
||||
return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
|
||||
}
|
||||
|
||||
string InferenceContext::DebugString() const {
|
||||
return strings::StrCat("InferenceContext for node: ",
|
||||
ProtoDebugString(node_def_));
|
||||
}
|
||||
|
||||
Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
|
||||
ShapeHandle* out) {
|
||||
const int32 existing = Rank(shape);
|
||||
|
@ -259,6 +259,9 @@ class InferenceContext {
|
||||
string DebugString(ShapeHandle s);
|
||||
string DebugString(DimensionHandle d);
|
||||
|
||||
// Describes the whole context, for debugging purposes.
|
||||
string DebugString() const;
|
||||
|
||||
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
|
||||
// the shape with asserted rank in <*out>. Otherwise return an error.
|
||||
//
|
||||
|
@ -97,10 +97,9 @@ Status PadShapeFn(InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// tensor value was provided for paddings_t; doublecheck n_dim value is the
|
||||
// same.
|
||||
const auto num_dims = c->Value(n_dim);
|
||||
DCHECK_EQ(num_dims, paddings_t->shape().dim_size(0));
|
||||
const int64 num_dims = paddings_t->shape().dim_size(0);
|
||||
TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
|
||||
TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
|
||||
|
||||
if (paddings_t->dtype() == DT_INT32) {
|
||||
return PadKnown<int32>(c, input, paddings_t, num_dims);
|
||||
@ -440,7 +439,8 @@ REGISTER_OP("SplitV")
|
||||
c->set_output(i, output_shape);
|
||||
}
|
||||
} else {
|
||||
// Determine the output shape if split dimension and split sizes are known
|
||||
// Determine the output shape if split dimension and split sizes are
|
||||
// known.
|
||||
int64 split_dim = c->Value(split_dimension);
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
|
||||
std::vector<int64> data;
|
||||
@ -451,12 +451,12 @@ REGISTER_OP("SplitV")
|
||||
}
|
||||
if (num_outputs != data.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Length of size_splits should be equal to num_outputs");
|
||||
"Length of size_splits should be equal to num_outputs");
|
||||
}
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
output_shape = c->UnknownShapeOfRank(rank);
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape));
|
||||
TF_RETURN_IF_ERROR(c->ReplaceDim(input, split_dim,
|
||||
c->MakeDim(data[i]), &output_shape));
|
||||
c->set_output(i, output_shape);
|
||||
}
|
||||
}
|
||||
@ -1370,8 +1370,8 @@ REGISTER_OP("GatherNd")
|
||||
if (c->Value(r_dim) > c->Rank(params)) {
|
||||
return errors::InvalidArgument(
|
||||
"indices.shape[-1] must be <= params.rank, but saw indices shape: ",
|
||||
c->DebugString(indices), " and params shape: ",
|
||||
c->DebugString(params));
|
||||
c->DebugString(indices),
|
||||
" and params shape: ", c->DebugString(params));
|
||||
}
|
||||
|
||||
// Remove r_dim from indices to get output.
|
||||
@ -1906,12 +1906,12 @@ REGISTER_OP("ReverseSequence")
|
||||
// Validate batch_dim and seq_dim against input.
|
||||
const int32 input_rank = c->Rank(input);
|
||||
if (batch_dim >= input_rank) {
|
||||
return errors::InvalidArgument("batch_dim must be < input rank: ",
|
||||
batch_dim, " vs. ", input_rank);
|
||||
return errors::InvalidArgument(
|
||||
"batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
|
||||
}
|
||||
if (seq_dim >= input_rank) {
|
||||
return errors::InvalidArgument("seq_dim must be < input rank: ",
|
||||
seq_dim, " vs. ", input_rank);
|
||||
return errors::InvalidArgument(
|
||||
"seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
|
||||
}
|
||||
|
||||
DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
|
||||
@ -3790,8 +3790,9 @@ REGISTER_OP("SpaceToDepth")
|
||||
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, 3), block_size * block_size,
|
||||
&output_depth));
|
||||
|
||||
c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
|
||||
output_width, output_depth}));
|
||||
c->set_output(0,
|
||||
c->MakeShape({c->Dim(input, 0), output_height, output_width,
|
||||
output_depth}));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
@ -3895,8 +3896,9 @@ REGISTER_OP("DepthToSpace")
|
||||
TF_RETURN_IF_ERROR(c->Divide(c->Dim(input, 3), block_size * block_size,
|
||||
true /* evenly_divisible */, &output_depth));
|
||||
|
||||
c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
|
||||
output_width, output_depth}));
|
||||
c->set_output(0,
|
||||
c->MakeShape({c->Dim(input, 0), output_height, output_width,
|
||||
output_depth}));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
@ -4772,8 +4774,9 @@ Status ScatterNdShape(InferenceContext* c) {
|
||||
Status s = c->Merge(prefix_indices, prefix_updates, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"The outer ", outer_dims, " dimensions of indices.shape=",
|
||||
c->DebugString(indices_shape), " must match the outer ", outer_dims,
|
||||
"The outer ", outer_dims,
|
||||
" dimensions of indices.shape=", c->DebugString(indices_shape),
|
||||
" must match the outer ", outer_dims,
|
||||
" dimensions of updates.shape=", c->DebugString(updates_shape),
|
||||
": ", s.error_message());
|
||||
}
|
||||
|
@ -331,6 +331,7 @@ TEST(ArrayOpsTest, PadD_ShapeFn) {
|
||||
INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
|
||||
INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
|
||||
INFER_OK(op, "?;[3,2]", "[?,?,?]");
|
||||
INFER_OK(op, "?;?", "[?,?,?]");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,6 +43,14 @@ Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
|
||||
width = c->UnknownDim();
|
||||
height = c->UnknownDim();
|
||||
} else {
|
||||
// TODO(petewarden) - Remove once we have constant evaluation in C++ only.
|
||||
if (size_tensor->dtype() != DT_INT32) {
|
||||
return errors::InvalidArgument(
|
||||
"Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
|
||||
"but got ",
|
||||
DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
|
||||
" in ", c->DebugString());
|
||||
}
|
||||
auto vec = size_tensor->vec<int32>();
|
||||
height = c->MakeDim(vec(0));
|
||||
width = c->MakeDim(vec(1));
|
||||
@ -74,8 +82,9 @@ Status DecodeImageShapeFn(InferenceContext* c) {
|
||||
channels_dim = c->MakeDim(channels);
|
||||
}
|
||||
|
||||
c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim, channels_dim}));
|
||||
c->set_output(0,
|
||||
c->MakeShape({InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim, channels_dim}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -555,9 +564,10 @@ REGISTER_OP("DecodeGif")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||
c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim, 3}));
|
||||
c->set_output(0,
|
||||
c->MakeShape({InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim, 3}));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
|
@ -638,6 +638,11 @@ def _ConstantValue(tensor):
|
||||
return np.concatenate(values, axis=dim)
|
||||
elif tensor.op.type == "Pack":
|
||||
values = []
|
||||
# Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
|
||||
# and shouldn't be produced, but to deal sensibly with them here we check
|
||||
# and return None.
|
||||
if not tensor.op.inputs:
|
||||
return None
|
||||
for x in tensor.op.inputs:
|
||||
value = constant_value(x)
|
||||
if value is None:
|
||||
|
Loading…
Reference in New Issue
Block a user