Fix issues with shape inference during graph import.
Change: 144464718
This commit is contained in:
parent
7c96eadae6
commit
8f893368fc
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/node_def.pb_text.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/bounds_check.h"
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -80,8 +81,7 @@ InferenceContext::InferenceContext(
|
|||||||
PostInputInit(input_handle_shapes, input_handle_dtypes);
|
PostInputInit(input_handle_shapes, input_handle_dtypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
InferenceContext::~InferenceContext() {
|
InferenceContext::~InferenceContext() {}
|
||||||
}
|
|
||||||
|
|
||||||
Status InferenceContext::set_output(StringPiece output_name,
|
Status InferenceContext::set_output(StringPiece output_name,
|
||||||
const std::vector<ShapeHandle>& shapes) {
|
const std::vector<ShapeHandle>& shapes) {
|
||||||
@ -231,6 +231,11 @@ string InferenceContext::DebugString(DimensionHandle d) {
|
|||||||
return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
|
return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string InferenceContext::DebugString() const {
|
||||||
|
return strings::StrCat("InferenceContext for node: ",
|
||||||
|
ProtoDebugString(node_def_));
|
||||||
|
}
|
||||||
|
|
||||||
Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
|
Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
|
||||||
ShapeHandle* out) {
|
ShapeHandle* out) {
|
||||||
const int32 existing = Rank(shape);
|
const int32 existing = Rank(shape);
|
||||||
|
@ -259,6 +259,9 @@ class InferenceContext {
|
|||||||
string DebugString(ShapeHandle s);
|
string DebugString(ShapeHandle s);
|
||||||
string DebugString(DimensionHandle d);
|
string DebugString(DimensionHandle d);
|
||||||
|
|
||||||
|
// Describes the whole context, for debugging purposes.
|
||||||
|
string DebugString() const;
|
||||||
|
|
||||||
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
|
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
|
||||||
// the shape with asserted rank in <*out>. Otherwise return an error.
|
// the shape with asserted rank in <*out>. Otherwise return an error.
|
||||||
//
|
//
|
||||||
|
@ -97,10 +97,9 @@ Status PadShapeFn(InferenceContext* c) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// tensor value was provided for paddings_t; doublecheck n_dim value is the
|
const int64 num_dims = paddings_t->shape().dim_size(0);
|
||||||
// same.
|
TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
|
||||||
const auto num_dims = c->Value(n_dim);
|
TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
|
||||||
DCHECK_EQ(num_dims, paddings_t->shape().dim_size(0));
|
|
||||||
|
|
||||||
if (paddings_t->dtype() == DT_INT32) {
|
if (paddings_t->dtype() == DT_INT32) {
|
||||||
return PadKnown<int32>(c, input, paddings_t, num_dims);
|
return PadKnown<int32>(c, input, paddings_t, num_dims);
|
||||||
@ -440,7 +439,8 @@ REGISTER_OP("SplitV")
|
|||||||
c->set_output(i, output_shape);
|
c->set_output(i, output_shape);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Determine the output shape if split dimension and split sizes are known
|
// Determine the output shape if split dimension and split sizes are
|
||||||
|
// known.
|
||||||
int64 split_dim = c->Value(split_dimension);
|
int64 split_dim = c->Value(split_dimension);
|
||||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
|
||||||
std::vector<int64> data;
|
std::vector<int64> data;
|
||||||
@ -455,8 +455,8 @@ REGISTER_OP("SplitV")
|
|||||||
}
|
}
|
||||||
for (int i = 0; i < num_outputs; ++i) {
|
for (int i = 0; i < num_outputs; ++i) {
|
||||||
output_shape = c->UnknownShapeOfRank(rank);
|
output_shape = c->UnknownShapeOfRank(rank);
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(c->ReplaceDim(input, split_dim,
|
||||||
c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape));
|
c->MakeDim(data[i]), &output_shape));
|
||||||
c->set_output(i, output_shape);
|
c->set_output(i, output_shape);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1370,8 +1370,8 @@ REGISTER_OP("GatherNd")
|
|||||||
if (c->Value(r_dim) > c->Rank(params)) {
|
if (c->Value(r_dim) > c->Rank(params)) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"indices.shape[-1] must be <= params.rank, but saw indices shape: ",
|
"indices.shape[-1] must be <= params.rank, but saw indices shape: ",
|
||||||
c->DebugString(indices), " and params shape: ",
|
c->DebugString(indices),
|
||||||
c->DebugString(params));
|
" and params shape: ", c->DebugString(params));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove r_dim from indices to get output.
|
// Remove r_dim from indices to get output.
|
||||||
@ -1906,12 +1906,12 @@ REGISTER_OP("ReverseSequence")
|
|||||||
// Validate batch_dim and seq_dim against input.
|
// Validate batch_dim and seq_dim against input.
|
||||||
const int32 input_rank = c->Rank(input);
|
const int32 input_rank = c->Rank(input);
|
||||||
if (batch_dim >= input_rank) {
|
if (batch_dim >= input_rank) {
|
||||||
return errors::InvalidArgument("batch_dim must be < input rank: ",
|
return errors::InvalidArgument(
|
||||||
batch_dim, " vs. ", input_rank);
|
"batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
|
||||||
}
|
}
|
||||||
if (seq_dim >= input_rank) {
|
if (seq_dim >= input_rank) {
|
||||||
return errors::InvalidArgument("seq_dim must be < input rank: ",
|
return errors::InvalidArgument(
|
||||||
seq_dim, " vs. ", input_rank);
|
"seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
|
||||||
}
|
}
|
||||||
|
|
||||||
DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
|
DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
|
||||||
@ -3790,8 +3790,9 @@ REGISTER_OP("SpaceToDepth")
|
|||||||
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, 3), block_size * block_size,
|
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, 3), block_size * block_size,
|
||||||
&output_depth));
|
&output_depth));
|
||||||
|
|
||||||
c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
|
c->set_output(0,
|
||||||
output_width, output_depth}));
|
c->MakeShape({c->Dim(input, 0), output_height, output_width,
|
||||||
|
output_depth}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
@ -3895,8 +3896,9 @@ REGISTER_OP("DepthToSpace")
|
|||||||
TF_RETURN_IF_ERROR(c->Divide(c->Dim(input, 3), block_size * block_size,
|
TF_RETURN_IF_ERROR(c->Divide(c->Dim(input, 3), block_size * block_size,
|
||||||
true /* evenly_divisible */, &output_depth));
|
true /* evenly_divisible */, &output_depth));
|
||||||
|
|
||||||
c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
|
c->set_output(0,
|
||||||
output_width, output_depth}));
|
c->MakeShape({c->Dim(input, 0), output_height, output_width,
|
||||||
|
output_depth}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
@ -4772,8 +4774,9 @@ Status ScatterNdShape(InferenceContext* c) {
|
|||||||
Status s = c->Merge(prefix_indices, prefix_updates, &unused);
|
Status s = c->Merge(prefix_indices, prefix_updates, &unused);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"The outer ", outer_dims, " dimensions of indices.shape=",
|
"The outer ", outer_dims,
|
||||||
c->DebugString(indices_shape), " must match the outer ", outer_dims,
|
" dimensions of indices.shape=", c->DebugString(indices_shape),
|
||||||
|
" must match the outer ", outer_dims,
|
||||||
" dimensions of updates.shape=", c->DebugString(updates_shape),
|
" dimensions of updates.shape=", c->DebugString(updates_shape),
|
||||||
": ", s.error_message());
|
": ", s.error_message());
|
||||||
}
|
}
|
||||||
|
@ -331,6 +331,7 @@ TEST(ArrayOpsTest, PadD_ShapeFn) {
|
|||||||
INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
|
INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
|
||||||
INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
|
INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
|
||||||
INFER_OK(op, "?;[3,2]", "[?,?,?]");
|
INFER_OK(op, "?;[3,2]", "[?,?,?]");
|
||||||
|
INFER_OK(op, "?;?", "[?,?,?]");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,6 +43,14 @@ Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
|
|||||||
width = c->UnknownDim();
|
width = c->UnknownDim();
|
||||||
height = c->UnknownDim();
|
height = c->UnknownDim();
|
||||||
} else {
|
} else {
|
||||||
|
// TODO(petewarden) - Remove once we have constant evaluation in C++ only.
|
||||||
|
if (size_tensor->dtype() != DT_INT32) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
|
||||||
|
"but got ",
|
||||||
|
DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
|
||||||
|
" in ", c->DebugString());
|
||||||
|
}
|
||||||
auto vec = size_tensor->vec<int32>();
|
auto vec = size_tensor->vec<int32>();
|
||||||
height = c->MakeDim(vec(0));
|
height = c->MakeDim(vec(0));
|
||||||
width = c->MakeDim(vec(1));
|
width = c->MakeDim(vec(1));
|
||||||
@ -74,7 +82,8 @@ Status DecodeImageShapeFn(InferenceContext* c) {
|
|||||||
channels_dim = c->MakeDim(channels);
|
channels_dim = c->MakeDim(channels);
|
||||||
}
|
}
|
||||||
|
|
||||||
c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
|
c->set_output(0,
|
||||||
|
c->MakeShape({InferenceContext::kUnknownDim,
|
||||||
InferenceContext::kUnknownDim, channels_dim}));
|
InferenceContext::kUnknownDim, channels_dim}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -555,7 +564,8 @@ REGISTER_OP("DecodeGif")
|
|||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
ShapeHandle unused;
|
ShapeHandle unused;
|
||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||||
c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
|
c->set_output(0,
|
||||||
|
c->MakeShape({InferenceContext::kUnknownDim,
|
||||||
InferenceContext::kUnknownDim,
|
InferenceContext::kUnknownDim,
|
||||||
InferenceContext::kUnknownDim, 3}));
|
InferenceContext::kUnknownDim, 3}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -638,6 +638,11 @@ def _ConstantValue(tensor):
|
|||||||
return np.concatenate(values, axis=dim)
|
return np.concatenate(values, axis=dim)
|
||||||
elif tensor.op.type == "Pack":
|
elif tensor.op.type == "Pack":
|
||||||
values = []
|
values = []
|
||||||
|
# Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
|
||||||
|
# and shouldn't be produced, but to deal sensibly with them here we check
|
||||||
|
# and return None.
|
||||||
|
if not tensor.op.inputs:
|
||||||
|
return None
|
||||||
for x in tensor.op.inputs:
|
for x in tensor.op.inputs:
|
||||||
value = constant_value(x)
|
value = constant_value(x)
|
||||||
if value is None:
|
if value is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user