Move ValidateSparseTensor to common_shape_fns.h.

Change: 139112567
This commit is contained in:
A. Unique TensorFlower 2016-11-14 13:22:48 -08:00 committed by TensorFlower Gardener
parent 750c98508c
commit 0c20187645
7 changed files with 202 additions and 187 deletions

View File

@ -173,8 +173,8 @@ REGISTER_OP("DenseToSparseSetOperation")
} else { } else {
output_rank = c->UnknownDim(); output_rank = c->UnknownDim();
} }
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
c->ValidateSparseTensor(c->input(1), c->input(2), c->input(3))); c, c->input(1), c->input(2), c->input(3)));
DimensionHandle output_num_elements = c->Dim(input0_shape, 0); DimensionHandle output_num_elements = c->Dim(input0_shape, 0);
if (!c->ValueKnown(output_num_elements)) { if (!c->ValueKnown(output_num_elements)) {
output_num_elements = c->UnknownDim(); output_num_elements = c->UnknownDim();
@ -239,10 +239,10 @@ REGISTER_OP("SparseToSparseSetOperation")
} }
// The following should stay in sync with `ComputeSparseToSparse` shape // The following should stay in sync with `ComputeSparseToSparse` shape
// assertions in kernels/set_kernels.cc. // assertions in kernels/set_kernels.cc.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2))); c, c->input(0), c->input(1), c->input(2)));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5))); c, c->input(3), c->input(4), c->input(5)));
c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim())); c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
c->set_output(1, c->Vector(c->UnknownDim())); c->set_output(1, c->Vector(c->UnknownDim()));
c->set_output(2, c->Vector(c->UnknownDim())); c->set_output(2, c->Vector(c->UnknownDim()));

View File

@ -860,5 +860,46 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
return Status::OK(); return Status::OK();
} }
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape) {
// Validate ranks.
ShapeHandle unused_shape;
TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
// Number of elements in indices and values must match.
DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
if (c->ValueKnown(num_index_elements_dim)) {
DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
if (c->ValueKnown(num_values_elements_dim)) {
int64 num_index_elements = c->Value(num_index_elements_dim);
int64 num_values_elements = c->Value(num_values_elements_dim);
if (num_index_elements != num_values_elements) {
return errors::InvalidArgument("Number of elements in index (",
num_index_elements, ") and values (",
num_values_elements, ") do not match.");
}
}
}
// Rank embedded in indices must match shape.
DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
if (c->ValueKnown(index_rank_dim)) {
DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
if (c->ValueKnown(shape_rank_dim)) {
int64 index_rank = c->Value(index_rank_dim);
int32 shape_rank = c->Value(shape_rank_dim);
if (index_rank != shape_rank) {
return errors::InvalidArgument("Index rank (", index_rank,
") and shape rank (", shape_rank,
") do not match.");
}
}
}
return Status::OK();
}
} // namespace shape_inference } // namespace shape_inference
} // namespace tensorflow } // namespace tensorflow

View File

@ -203,6 +203,11 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c);
// Tested by ops/math_ops_test.cc. // Tested by ops/math_ops_test.cc.
Status BroadcastBinaryOpShapeFn(InferenceContext* c); Status BroadcastBinaryOpShapeFn(InferenceContext* c);
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape);
} // namespace shape_inference } // namespace shape_inference
} // namespace tensorflow } // namespace tensorflow

View File

@ -40,6 +40,19 @@ TensorShapeProto Unknown() {
return ret; return ret;
} }
OpDef MakeOpDef(int num_inputs, int num_outputs) {
OpRegistrationData op_reg_data;
OpDefBuilder b("dummy");
for (int i = 0; i < num_inputs; ++i) {
b.Input(strings::StrCat("i", i, ": float"));
}
for (int i = 0; i < num_outputs; ++i) {
b.Output(strings::StrCat("o", i, ": float"));
}
CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
return op_reg_data.op_def;
}
} // namespace } // namespace
TEST(CommonShapeFnsTest, NoOutputShapeTest) { TEST(CommonShapeFnsTest, NoOutputShapeTest) {
@ -840,5 +853,138 @@ TEST(CommonShapeFnsTest, ReduceForReduceJoin_ShapeFn) {
INFER_OK(op, "[?,?,?];[2]", "?"); INFER_OK(op, "[?,?,?];[2]", "?");
} }
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
{}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
}
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
}
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
EXPECT_EQ(error::INVALID_ARGUMENT,
ValidateSparseTensor(&c, indices, values, shape).code());
}
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
EXPECT_EQ(error::INVALID_ARGUMENT,
ValidateSparseTensor(&c, indices, values, shape).code());
}
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
EXPECT_EQ(error::INVALID_ARGUMENT,
ValidateSparseTensor(&c, indices, values, shape).code());
}
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
}
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
}
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
}
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
}
TEST(CommonShapeFnsTest, ValidateSparseTensor) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
}
} // namespace shape_inference } // namespace shape_inference
} // namespace tensorflow } // namespace tensorflow

View File

@ -424,50 +424,6 @@ class InferenceContext {
return output_handle_dtype_[idx]; return output_handle_dtype_[idx];
} }
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(ShapeHandle indices_shape,
ShapeHandle values_shape,
ShapeHandle shape_shape) {
// Validate ranks.
ShapeHandle unused_shape;
TF_RETURN_IF_ERROR(WithRank(indices_shape, 2, &unused_shape));
TF_RETURN_IF_ERROR(WithRank(values_shape, 1, &unused_shape));
TF_RETURN_IF_ERROR(WithRank(shape_shape, 1, &unused_shape));
// Number of elements in indices and values must match.
DimensionHandle num_index_elements_dim = Dim(indices_shape, 0);
if (ValueKnown(num_index_elements_dim)) {
DimensionHandle num_values_elements_dim = Dim(values_shape, 0);
if (ValueKnown(num_values_elements_dim)) {
int64 num_index_elements = Value(num_index_elements_dim);
int64 num_values_elements = Value(num_values_elements_dim);
if (num_index_elements != num_values_elements) {
return errors::InvalidArgument(
"Number of elements in index (", num_index_elements,
") and values (", num_values_elements, ") do not match.");
}
}
}
// Rank embedded in indices must match shape.
DimensionHandle index_rank_dim = Dim(indices_shape, 1);
if (ValueKnown(index_rank_dim)) {
DimensionHandle shape_rank_dim = Dim(shape_shape, 0);
if (ValueKnown(shape_rank_dim)) {
int64 index_rank = Value(index_rank_dim);
int32 shape_rank = Value(shape_rank_dim);
if (index_rank != shape_rank) {
return errors::InvalidArgument("Index rank (", index_rank,
") and shape rank (", shape_rank,
") do not match.");
}
}
}
return Status::OK();
}
// Note that shape functions should usually call MakeShapeFromShapeTensor, // Note that shape functions should usually call MakeShapeFromShapeTensor,
// as it does more analysis to provide partial shapes. // as it does more analysis to provide partial shapes.
// //

View File

@ -1264,138 +1264,5 @@ TEST_F(ShapeInferenceTest, Max) {
EXPECT_TRUE(SameHandle(d_2, out)); EXPECT_TRUE(SameHandle(d_2, out));
} }
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
{}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
}
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
}
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
EXPECT_EQ(error::INVALID_ARGUMENT,
c.ValidateSparseTensor(indices, values, shape).code());
}
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
EXPECT_EQ(error::INVALID_ARGUMENT,
c.ValidateSparseTensor(indices, values, shape).code());
}
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
EXPECT_EQ(error::INVALID_ARGUMENT,
c.ValidateSparseTensor(indices, values, shape).code());
}
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
}
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
}
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
}
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
}
TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
auto indices = c.input(0);
auto values = c.input(1);
auto shape = c.input(2);
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
}
} // namespace shape_inference } // namespace shape_inference
} // namespace tensorflow } // namespace tensorflow

View File

@ -997,10 +997,10 @@ REGISTER_OP("EditDistance")
.Attr("T: type") .Attr("T: type")
.Output("output: float") .Output("output: float")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2))); c, c->input(0), c->input(1), c->input(2)));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5))); c, c->input(3), c->input(4), c->input(5)));
const Tensor* hypothesis_shape_t = c->input_tensor(2); const Tensor* hypothesis_shape_t = c->input_tensor(2);
const Tensor* truth_shape_t = c->input_tensor(5); const Tensor* truth_shape_t = c->input_tensor(5);
if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) { if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {