Move ValidateSparseTensor to common_shape_fns.h.
Change: 139112567
This commit is contained in:
parent
750c98508c
commit
0c20187645
@ -173,8 +173,8 @@ REGISTER_OP("DenseToSparseSetOperation")
|
||||
} else {
|
||||
output_rank = c->UnknownDim();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ValidateSparseTensor(c->input(1), c->input(2), c->input(3)));
|
||||
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||
c, c->input(1), c->input(2), c->input(3)));
|
||||
DimensionHandle output_num_elements = c->Dim(input0_shape, 0);
|
||||
if (!c->ValueKnown(output_num_elements)) {
|
||||
output_num_elements = c->UnknownDim();
|
||||
@ -239,10 +239,10 @@ REGISTER_OP("SparseToSparseSetOperation")
|
||||
}
|
||||
// The following should stay in sync with `ComputeSparseToSparse` shape
|
||||
// assertions in kernels/set_kernels.cc.
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2)));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5)));
|
||||
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||
c, c->input(0), c->input(1), c->input(2)));
|
||||
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||
c, c->input(3), c->input(4), c->input(5)));
|
||||
c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
|
||||
c->set_output(1, c->Vector(c->UnknownDim()));
|
||||
c->set_output(2, c->Vector(c->UnknownDim()));
|
||||
|
@ -860,5 +860,46 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
|
||||
ShapeHandle values_shape, ShapeHandle shape_shape) {
|
||||
// Validate ranks.
|
||||
ShapeHandle unused_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
|
||||
|
||||
// Number of elements in indices and values must match.
|
||||
DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
|
||||
if (c->ValueKnown(num_index_elements_dim)) {
|
||||
DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
|
||||
if (c->ValueKnown(num_values_elements_dim)) {
|
||||
int64 num_index_elements = c->Value(num_index_elements_dim);
|
||||
int64 num_values_elements = c->Value(num_values_elements_dim);
|
||||
if (num_index_elements != num_values_elements) {
|
||||
return errors::InvalidArgument("Number of elements in index (",
|
||||
num_index_elements, ") and values (",
|
||||
num_values_elements, ") do not match.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rank embedded in indices must match shape.
|
||||
DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
|
||||
if (c->ValueKnown(index_rank_dim)) {
|
||||
DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
|
||||
if (c->ValueKnown(shape_rank_dim)) {
|
||||
int64 index_rank = c->Value(index_rank_dim);
|
||||
int32 shape_rank = c->Value(shape_rank_dim);
|
||||
if (index_rank != shape_rank) {
|
||||
return errors::InvalidArgument("Index rank (", index_rank,
|
||||
") and shape rank (", shape_rank,
|
||||
") do not match.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace shape_inference
|
||||
} // namespace tensorflow
|
||||
|
@ -203,6 +203,11 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c);
|
||||
// Tested by ops/math_ops_test.cc.
|
||||
Status BroadcastBinaryOpShapeFn(InferenceContext* c);
|
||||
|
||||
// Validates the 3 component tensors of a sparse tensor have the proper
|
||||
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
|
||||
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
|
||||
ShapeHandle values_shape, ShapeHandle shape_shape);
|
||||
|
||||
} // namespace shape_inference
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -40,6 +40,19 @@ TensorShapeProto Unknown() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||
OpRegistrationData op_reg_data;
|
||||
OpDefBuilder b("dummy");
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
b.Input(strings::StrCat("i", i, ": float"));
|
||||
}
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
b.Output(strings::StrCat("o", i, ": float"));
|
||||
}
|
||||
CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
|
||||
return op_reg_data.op_def;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(CommonShapeFnsTest, NoOutputShapeTest) {
|
||||
@ -840,5 +853,138 @@ TEST(CommonShapeFnsTest, ReduceForReduceJoin_ShapeFn) {
|
||||
INFER_OK(op, "[?,?,?];[2]", "?");
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
|
||||
{}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||
ValidateSparseTensor(&c, indices, values, shape).code());
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||
ValidateSparseTensor(&c, indices, values, shape).code());
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||
ValidateSparseTensor(&c, indices, values, shape).code());
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||
}
|
||||
|
||||
} // namespace shape_inference
|
||||
} // namespace tensorflow
|
||||
|
@ -424,50 +424,6 @@ class InferenceContext {
|
||||
return output_handle_dtype_[idx];
|
||||
}
|
||||
|
||||
// Validates the 3 component tensors of a sparse tensor have the proper
|
||||
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
|
||||
Status ValidateSparseTensor(ShapeHandle indices_shape,
|
||||
ShapeHandle values_shape,
|
||||
ShapeHandle shape_shape) {
|
||||
// Validate ranks.
|
||||
ShapeHandle unused_shape;
|
||||
TF_RETURN_IF_ERROR(WithRank(indices_shape, 2, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(WithRank(values_shape, 1, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(WithRank(shape_shape, 1, &unused_shape));
|
||||
|
||||
// Number of elements in indices and values must match.
|
||||
DimensionHandle num_index_elements_dim = Dim(indices_shape, 0);
|
||||
if (ValueKnown(num_index_elements_dim)) {
|
||||
DimensionHandle num_values_elements_dim = Dim(values_shape, 0);
|
||||
if (ValueKnown(num_values_elements_dim)) {
|
||||
int64 num_index_elements = Value(num_index_elements_dim);
|
||||
int64 num_values_elements = Value(num_values_elements_dim);
|
||||
if (num_index_elements != num_values_elements) {
|
||||
return errors::InvalidArgument(
|
||||
"Number of elements in index (", num_index_elements,
|
||||
") and values (", num_values_elements, ") do not match.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rank embedded in indices must match shape.
|
||||
DimensionHandle index_rank_dim = Dim(indices_shape, 1);
|
||||
if (ValueKnown(index_rank_dim)) {
|
||||
DimensionHandle shape_rank_dim = Dim(shape_shape, 0);
|
||||
if (ValueKnown(shape_rank_dim)) {
|
||||
int64 index_rank = Value(index_rank_dim);
|
||||
int32 shape_rank = Value(shape_rank_dim);
|
||||
if (index_rank != shape_rank) {
|
||||
return errors::InvalidArgument("Index rank (", index_rank,
|
||||
") and shape rank (", shape_rank,
|
||||
") do not match.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Note that shape functions should usually call MakeShapeFromShapeTensor,
|
||||
// as it does more analysis to provide partial shapes.
|
||||
//
|
||||
|
@ -1264,138 +1264,5 @@ TEST_F(ShapeInferenceTest, Max) {
|
||||
EXPECT_TRUE(SameHandle(d_2, out));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
|
||||
{}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||
c.ValidateSparseTensor(indices, values, shape).code());
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||
c.ValidateSparseTensor(indices, values, shape).code());
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||
c.ValidateSparseTensor(indices, values, shape).code());
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
auto indices = c.input(0);
|
||||
auto values = c.input(1);
|
||||
auto shape = c.input(2);
|
||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
||||
}
|
||||
|
||||
} // namespace shape_inference
|
||||
} // namespace tensorflow
|
||||
|
@ -997,10 +997,10 @@ REGISTER_OP("EditDistance")
|
||||
.Attr("T: type")
|
||||
.Output("output: float")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2)));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5)));
|
||||
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||
c, c->input(0), c->input(1), c->input(2)));
|
||||
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||
c, c->input(3), c->input(4), c->input(5)));
|
||||
const Tensor* hypothesis_shape_t = c->input_tensor(2);
|
||||
const Tensor* truth_shape_t = c->input_tensor(5);
|
||||
if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
|
||||
|
Loading…
Reference in New Issue
Block a user