Move ValidateSparseTensor to common_shape_fns.h.
Change: 139112567
This commit is contained in:
parent
750c98508c
commit
0c20187645
@ -173,8 +173,8 @@ REGISTER_OP("DenseToSparseSetOperation")
|
|||||||
} else {
|
} else {
|
||||||
output_rank = c->UnknownDim();
|
output_rank = c->UnknownDim();
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||||
c->ValidateSparseTensor(c->input(1), c->input(2), c->input(3)));
|
c, c->input(1), c->input(2), c->input(3)));
|
||||||
DimensionHandle output_num_elements = c->Dim(input0_shape, 0);
|
DimensionHandle output_num_elements = c->Dim(input0_shape, 0);
|
||||||
if (!c->ValueKnown(output_num_elements)) {
|
if (!c->ValueKnown(output_num_elements)) {
|
||||||
output_num_elements = c->UnknownDim();
|
output_num_elements = c->UnknownDim();
|
||||||
@ -239,10 +239,10 @@ REGISTER_OP("SparseToSparseSetOperation")
|
|||||||
}
|
}
|
||||||
// The following should stay in sync with `ComputeSparseToSparse` shape
|
// The following should stay in sync with `ComputeSparseToSparse` shape
|
||||||
// assertions in kernels/set_kernels.cc.
|
// assertions in kernels/set_kernels.cc.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||||
c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2)));
|
c, c->input(0), c->input(1), c->input(2)));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||||
c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5)));
|
c, c->input(3), c->input(4), c->input(5)));
|
||||||
c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
|
c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
|
||||||
c->set_output(1, c->Vector(c->UnknownDim()));
|
c->set_output(1, c->Vector(c->UnknownDim()));
|
||||||
c->set_output(2, c->Vector(c->UnknownDim()));
|
c->set_output(2, c->Vector(c->UnknownDim()));
|
||||||
|
@ -860,5 +860,46 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
|
||||||
|
ShapeHandle values_shape, ShapeHandle shape_shape) {
|
||||||
|
// Validate ranks.
|
||||||
|
ShapeHandle unused_shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
|
||||||
|
|
||||||
|
// Number of elements in indices and values must match.
|
||||||
|
DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
|
||||||
|
if (c->ValueKnown(num_index_elements_dim)) {
|
||||||
|
DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
|
||||||
|
if (c->ValueKnown(num_values_elements_dim)) {
|
||||||
|
int64 num_index_elements = c->Value(num_index_elements_dim);
|
||||||
|
int64 num_values_elements = c->Value(num_values_elements_dim);
|
||||||
|
if (num_index_elements != num_values_elements) {
|
||||||
|
return errors::InvalidArgument("Number of elements in index (",
|
||||||
|
num_index_elements, ") and values (",
|
||||||
|
num_values_elements, ") do not match.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rank embedded in indices must match shape.
|
||||||
|
DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
|
||||||
|
if (c->ValueKnown(index_rank_dim)) {
|
||||||
|
DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
|
||||||
|
if (c->ValueKnown(shape_rank_dim)) {
|
||||||
|
int64 index_rank = c->Value(index_rank_dim);
|
||||||
|
int32 shape_rank = c->Value(shape_rank_dim);
|
||||||
|
if (index_rank != shape_rank) {
|
||||||
|
return errors::InvalidArgument("Index rank (", index_rank,
|
||||||
|
") and shape rank (", shape_rank,
|
||||||
|
") do not match.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace shape_inference
|
} // namespace shape_inference
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -203,6 +203,11 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c);
|
|||||||
// Tested by ops/math_ops_test.cc.
|
// Tested by ops/math_ops_test.cc.
|
||||||
Status BroadcastBinaryOpShapeFn(InferenceContext* c);
|
Status BroadcastBinaryOpShapeFn(InferenceContext* c);
|
||||||
|
|
||||||
|
// Validates the 3 component tensors of a sparse tensor have the proper
|
||||||
|
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
|
||||||
|
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
|
||||||
|
ShapeHandle values_shape, ShapeHandle shape_shape);
|
||||||
|
|
||||||
} // namespace shape_inference
|
} // namespace shape_inference
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -40,6 +40,19 @@ TensorShapeProto Unknown() {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||||
|
OpRegistrationData op_reg_data;
|
||||||
|
OpDefBuilder b("dummy");
|
||||||
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
|
b.Input(strings::StrCat("i", i, ": float"));
|
||||||
|
}
|
||||||
|
for (int i = 0; i < num_outputs; ++i) {
|
||||||
|
b.Output(strings::StrCat("o", i, ": float"));
|
||||||
|
}
|
||||||
|
CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
|
||||||
|
return op_reg_data.op_def;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TEST(CommonShapeFnsTest, NoOutputShapeTest) {
|
TEST(CommonShapeFnsTest, NoOutputShapeTest) {
|
||||||
@ -840,5 +853,138 @@ TEST(CommonShapeFnsTest, ReduceForReduceJoin_ShapeFn) {
|
|||||||
INFER_OK(op, "[?,?,?];[2]", "?");
|
INFER_OK(op, "[?,?,?];[2]", "?");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
|
||||||
|
{}, {}, {}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
|
||||||
|
{}, {}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
|
||||||
|
{}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||||
|
ValidateSparseTensor(&c, indices, values, shape).code());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
|
||||||
|
{}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||||
|
ValidateSparseTensor(&c, indices, values, shape).code());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
|
||||||
|
{}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||||
|
ValidateSparseTensor(&c, indices, values, shape).code());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
|
||||||
|
{}, {}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
|
||||||
|
{}, {}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
|
||||||
|
{}, {}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
|
||||||
|
{}, {}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CommonShapeFnsTest, ValidateSparseTensor) {
|
||||||
|
NodeDef def;
|
||||||
|
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
|
||||||
|
{}, {});
|
||||||
|
EXPECT_EQ(3, c.num_inputs());
|
||||||
|
EXPECT_EQ(1, c.num_outputs());
|
||||||
|
|
||||||
|
auto indices = c.input(0);
|
||||||
|
auto values = c.input(1);
|
||||||
|
auto shape = c.input(2);
|
||||||
|
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace shape_inference
|
} // namespace shape_inference
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -424,50 +424,6 @@ class InferenceContext {
|
|||||||
return output_handle_dtype_[idx];
|
return output_handle_dtype_[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validates the 3 component tensors of a sparse tensor have the proper
|
|
||||||
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
|
|
||||||
Status ValidateSparseTensor(ShapeHandle indices_shape,
|
|
||||||
ShapeHandle values_shape,
|
|
||||||
ShapeHandle shape_shape) {
|
|
||||||
// Validate ranks.
|
|
||||||
ShapeHandle unused_shape;
|
|
||||||
TF_RETURN_IF_ERROR(WithRank(indices_shape, 2, &unused_shape));
|
|
||||||
TF_RETURN_IF_ERROR(WithRank(values_shape, 1, &unused_shape));
|
|
||||||
TF_RETURN_IF_ERROR(WithRank(shape_shape, 1, &unused_shape));
|
|
||||||
|
|
||||||
// Number of elements in indices and values must match.
|
|
||||||
DimensionHandle num_index_elements_dim = Dim(indices_shape, 0);
|
|
||||||
if (ValueKnown(num_index_elements_dim)) {
|
|
||||||
DimensionHandle num_values_elements_dim = Dim(values_shape, 0);
|
|
||||||
if (ValueKnown(num_values_elements_dim)) {
|
|
||||||
int64 num_index_elements = Value(num_index_elements_dim);
|
|
||||||
int64 num_values_elements = Value(num_values_elements_dim);
|
|
||||||
if (num_index_elements != num_values_elements) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Number of elements in index (", num_index_elements,
|
|
||||||
") and values (", num_values_elements, ") do not match.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rank embedded in indices must match shape.
|
|
||||||
DimensionHandle index_rank_dim = Dim(indices_shape, 1);
|
|
||||||
if (ValueKnown(index_rank_dim)) {
|
|
||||||
DimensionHandle shape_rank_dim = Dim(shape_shape, 0);
|
|
||||||
if (ValueKnown(shape_rank_dim)) {
|
|
||||||
int64 index_rank = Value(index_rank_dim);
|
|
||||||
int32 shape_rank = Value(shape_rank_dim);
|
|
||||||
if (index_rank != shape_rank) {
|
|
||||||
return errors::InvalidArgument("Index rank (", index_rank,
|
|
||||||
") and shape rank (", shape_rank,
|
|
||||||
") do not match.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Note that shape functions should usually call MakeShapeFromShapeTensor,
|
// Note that shape functions should usually call MakeShapeFromShapeTensor,
|
||||||
// as it does more analysis to provide partial shapes.
|
// as it does more analysis to provide partial shapes.
|
||||||
//
|
//
|
||||||
|
@ -1264,138 +1264,5 @@ TEST_F(ShapeInferenceTest, Max) {
|
|||||||
EXPECT_TRUE(SameHandle(d_2, out));
|
EXPECT_TRUE(SameHandle(d_2, out));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
|
|
||||||
{}, {}, {}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
|
|
||||||
{}, {}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
|
|
||||||
{}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
|
||||||
c.ValidateSparseTensor(indices, values, shape).code());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
|
|
||||||
{}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
|
||||||
c.ValidateSparseTensor(indices, values, shape).code());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
|
|
||||||
{}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
|
||||||
c.ValidateSparseTensor(indices, values, shape).code());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
|
|
||||||
{}, {}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
|
|
||||||
{}, {}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
|
|
||||||
{}, {}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
|
|
||||||
{}, {}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
|
|
||||||
NodeDef def;
|
|
||||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
|
|
||||||
{}, {});
|
|
||||||
EXPECT_EQ(3, c.num_inputs());
|
|
||||||
EXPECT_EQ(1, c.num_outputs());
|
|
||||||
|
|
||||||
auto indices = c.input(0);
|
|
||||||
auto values = c.input(1);
|
|
||||||
auto shape = c.input(2);
|
|
||||||
TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace shape_inference
|
} // namespace shape_inference
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -997,10 +997,10 @@ REGISTER_OP("EditDistance")
|
|||||||
.Attr("T: type")
|
.Attr("T: type")
|
||||||
.Output("output: float")
|
.Output("output: float")
|
||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||||
c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2)));
|
c, c->input(0), c->input(1), c->input(2)));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||||
c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5)));
|
c, c->input(3), c->input(4), c->input(5)));
|
||||||
const Tensor* hypothesis_shape_t = c->input_tensor(2);
|
const Tensor* hypothesis_shape_t = c->input_tensor(2);
|
||||||
const Tensor* truth_shape_t = c->input_tensor(5);
|
const Tensor* truth_shape_t = c->input_tensor(5);
|
||||||
if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
|
if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
|
||||||
|
Loading…
Reference in New Issue
Block a user