Unify the scatter_nd* ops' shape inference code, which was duplicated in two places. Also clean up the error messages: rather than referring vaguely to the "inner" and "outer" dimensions of tensors, they now state directly which dimensions are being compared. Unify the eager and graph mode error messages as well, which allows removing some run_deprecated_v1 annotations.

Added some C++ shape inference tests as well.

PiperOrigin-RevId: 326378038
Change-Id: I58ab87ef0c476049da79c5896838a3b92649b772
Rohan Jain 2020-08-12 21:42:55 -07:00 committed by TensorFlower Gardener
parent 7f5fdceaf0
commit a01cf466aa
9 changed files with 180 additions and 164 deletions
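
For illustration, a minimal sketch of the new error format (assuming TF 2.x eager execution; the wording mirrors the updated tests below):

import tensorflow as tf

indices = tf.zeros([8, 2], tf.int32)
updates = tf.zeros([9], tf.float32)  # outermost dim 9 does not match 8
try:
  tf.scatter_nd(indices, updates, shape=[4, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)  # Dimensions [0,1) of indices[shape=[8,2]] = 8 must match
            # dimensions [0,1) of updates[shape=[9]] = 9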

View File

@@ -2257,66 +2257,57 @@ Status GatherNdShape(InferenceContext* c) {
   return Status::OK();
 }
 
-Status ScatterNdUpdateShape(InferenceContext* c) {
-  ShapeHandle input_shape = c->input(0);
-  if (c->input_handle_shapes_and_types(0) != nullptr) {
-    // This is called for tf.scatter_nd_update; input is a Variable handle.
-    const auto& shape_and_type = *(c->input_handle_shapes_and_types(0));
-    if (shape_and_type.size() == 1) {
-      input_shape = shape_and_type[0].shape;
-    }
-  }
-  ShapeHandle indices_shape;
-  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
-  ShapeHandle updates_shape;
-  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
+Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
+                            ShapeHandle updates_shape,
+                            ShapeHandle input_shape) {
   if (c->Value(c->NumElements(input_shape)) == 0 &&
       (c->Value(c->NumElements(indices_shape)) > 0 ||
        c->Value(c->NumElements(updates_shape)) > 0)) {
     return errors::InvalidArgument(
-        "Indices and updates specified for empty output shape");
+        "Indices and updates specified for empty input");
   }
 
   if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
-    const int64 num_outer_dims = c->Rank(indices_shape) - 1;
-    const DimensionHandle index_size = c->Dim(indices_shape, -1);
+    const int64 outer_dims = c->Rank(indices_shape) - 1;
+    const DimensionHandle ixdim = c->Dim(indices_shape, -1);
 
     // We can only do more validation if the last dimension of indices
     // is a known value.
-    if (c->ValueKnown(index_size)) {
-      const int64 ix = c->Value(index_size);
+    if (c->ValueKnown(ixdim)) {
+      int64 ix = c->Value(ixdim);
       ShapeHandle unused;
       ShapeHandle prefix_indices;
       TF_RETURN_IF_ERROR(
-          c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices));
+          c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
       ShapeHandle prefix_updates;
       TF_RETURN_IF_ERROR(
-          c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates));
+          c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
 
       Status s = c->Merge(prefix_indices, prefix_updates, &unused);
       if (!s.ok()) {
         return errors::InvalidArgument(
-            "The outer ", num_outer_dims,
-            " dimensions of indices.shape=", c->DebugString(indices_shape),
-            " must match the outer ", num_outer_dims,
-            " dimensions of updates.shape=", c->DebugString(updates_shape),
-            ": ", s.error_message());
+            "Dimensions [0,", outer_dims,
+            ") of indices[shape=", c->DebugString(indices_shape),
+            "] = ", c->DebugString(prefix_indices),
+            " must match dimensions [0,", outer_dims,
+            ") of updates[shape=", c->DebugString(updates_shape),
+            "] = ", c->DebugString(prefix_updates), ": ", s.error_message());
       }
 
-      ShapeHandle input_suffix;
-      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix));
+      ShapeHandle suffix_output;
+      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output));
       ShapeHandle suffix_updates;
       TF_RETURN_IF_ERROR(
-          c->Subshape(updates_shape, num_outer_dims, &suffix_updates));
-      s = c->Merge(input_suffix, suffix_updates, &unused);
+          c->Subshape(updates_shape, outer_dims, &suffix_updates));
+      s = c->Merge(suffix_output, suffix_updates, &unused);
       if (!s.ok()) {
         return errors::InvalidArgument(
-            "The inner ", c->Rank(input_shape) - ix,
-            " dimensions of input.shape=", c->DebugString(input_shape),
-            " must match the inner ", c->Rank(updates_shape) - num_outer_dims,
-            " dimensions of updates.shape=", c->DebugString(updates_shape),
-            ": ", s.error_message());
+            "Dimensions [", ix, ",", c->Rank(input_shape),
+            ") of input[shape=", c->DebugString(input_shape),
+            "] = ", c->DebugString(suffix_output), " must match dimensions [",
+            outer_dims, ",", c->Rank(updates_shape),
+            ") of updates[shape=", c->DebugString(updates_shape),
+            "] = ", c->DebugString(suffix_updates), ": ", s.error_message());
       }
     }
   }
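
The helper enforces two constraints: the leading dimensions of indices and updates must agree, and the trailing dimensions of updates must agree with the trailing dimensions of the input being scattered into. A hedged pure-Python restatement may make this clearer; check_scatter_nd_shapes is a hypothetical name for illustration, not TensorFlow API:

def check_scatter_nd_shapes(input_shape, indices_shape, updates_shape):
  outer_dims = len(indices_shape) - 1
  ix = indices_shape[-1]  # the last dim of indices indexes into the input
  if list(indices_shape[:outer_dims]) != list(updates_shape[:outer_dims]):
    raise ValueError(f"Dimensions [0,{outer_dims}) of indices must match "
                     f"dimensions [0,{outer_dims}) of updates")
  if list(updates_shape[outer_dims:]) != list(input_shape[ix:]):
    raise ValueError(f"Dimensions [{ix},{len(input_shape)}) of input must "
                     f"match dimensions [{outer_dims},{len(updates_shape)}) "
                     f"of updates")

check_scatter_nd_shapes([4, 3], [8, 2], [8])     # ok: scalar update per index
check_scatter_nd_shapes([4, 3], [8, 1], [8, 3])  # ok: row update per index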

View File

@@ -241,8 +241,9 @@ Status ValidateVariableResourceHandle(
 // Shape function for GatherNd operations.
 Status GatherNdShape(InferenceContext* c);
 
-// Shape function for ScatterNd update/add/sub/... operations.
-Status ScatterNdUpdateShape(InferenceContext* c);
+// Helper shape function for ScatterNd.../TensorScatter... operations.
+Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
+                            ShapeHandle updates_shape, ShapeHandle input_shape);
 
 // Shape function for ops with an explicit "shape" attribute.
 Status ExplicitShape(InferenceContext* c);

View File

@@ -100,29 +100,31 @@ class ScatterNdOp : public OpKernel {
     const int64 outer_dims = indices.shape().dims() - 1;
 
     for (int i = 0; i < outer_dims; ++i) {
-      OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
-                  errors::InvalidArgument(
-                      "Outer dimensions of indices and update must match. "
-                      "Indices shape: ",
-                      indices.shape().DebugString(),
-                      ", updates shape:", updates.shape().DebugString()));
+      OP_REQUIRES(
+          c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
+          errors::InvalidArgument(
+              "Dimensions [0,", outer_dims,
+              ") of indices[shape=", indices.shape().DebugString(),
+              "] must match dimensions [0,", outer_dims,
+              ") of updates[shape=", updates.shape().DebugString(), "]"));
     }
 
     const int64 ix = indices.shape().dim_size(outer_dims);
-    OP_REQUIRES(
-        c, updates.shape().dims() - outer_dims == shape.dims() - ix,
-        errors::InvalidArgument("Inner dimensions of output shape must match "
-                                "inner dimensions of updates shape. Output: ",
-                                shape.DebugString(),
-                                " updates: ", updates.shape().DebugString()));
+    OP_REQUIRES(c, updates.shape().dims() - outer_dims == shape.dims() - ix,
+                errors::InvalidArgument(
+                    "Dimensions [", ix, ",", shape.dims(), ") of input[shape=",
+                    shape.DebugString(), "] must match dimensions [",
+                    outer_dims, ",", updates.shape().dims(),
+                    ") of updates[shape=", updates.shape().DebugString(), "]"));
     for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
       OP_REQUIRES(
          c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
-          errors::InvalidArgument(
-              "The inner ", shape.dims() - ix,
-              " dimensions of output.shape=", shape.DebugString(),
-              " must match the inner ", updates.shape().dims() - outer_dims,
-              " dimensions of updates.shape=", updates.shape().DebugString()));
+          errors::InvalidArgument("Dimensions [", ix, ",", shape.dims(),
+                                  ") of input[shape=", shape.DebugString(),
+                                  "] must match dimensions [", outer_dims, ",",
+                                  updates.shape().dims(), ") of updates[shape=",
+                                  updates.shape().DebugString(), "]"));
     }
 
     OP_REQUIRES(c, shape_input.dims() == 1,
                 errors::InvalidArgument("Shape must be a vector"));
@@ -602,30 +604,35 @@ Status ValidateUpdateShape(const TensorShape& params_shape,
       (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1;
   const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1;
 
-  auto shape_err = [&]() {
+  auto shape_err_prefix = [&]() {
     return errors::InvalidArgument(
-        "Must have updates.shape = indices.shape[:batch_dim] + ",
-        "params_shape[slice_dim:], got updates.shape: ",
-        updates.shape().DebugString(),
-        ", indices.shape: ", indices.shape().DebugString(),
-        ", params_shape: ", params_shape.DebugString(),
-        ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim);
+        "Dimensions [0,", batch_dim,
+        ") of indices[shape=", indices.shape().DebugString(),
+        "] must match dimensions [0,", batch_dim,
+        ") of updates[shape=", updates.shape().DebugString(), "]");
   };
+  auto shape_err_suffix = [&]() {
+    return errors::InvalidArgument(
+        "Dimensions [", slice_dim, ",", params_shape.dims(),
+        ") of input[shape=", params_shape.DebugString(),
+        "] must match dimensions [", slice_dim, ",", updates.dims(),
+        ") of updates[shape=", updates.shape().DebugString(), "]");
+  };
 
-  if (updates.dims() < batch_dim) return shape_err();
+  if (updates.dims() < batch_dim) return shape_err_prefix();
   if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
-    return shape_err();
+    return shape_err_suffix();
   }
   if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
-    return shape_err();
+    return shape_err_suffix();
   }
   for (int d = 0; d < batch_dim; ++d) {
-    if (updates.dim_size(d) != indices.dim_size(d)) return shape_err();
+    if (updates.dim_size(d) != indices.dim_size(d)) return shape_err_prefix();
   }
   for (int d = 0; d < updates.dims() - batch_dim; ++d) {
     if (updates.dim_size(d + batch_dim) !=
         params_shape.dim_size(d + slice_dim)) {
-      return shape_err();
+      return shape_err_suffix();
     }
   }
   return Status::OK();
@@ -654,9 +661,9 @@ Status PrepareAndValidateInputs(const TensorShape& params_shape,
   if (updates.dim_size(0) != indices.dim_size(0)) {
     return errors::InvalidArgument(
-        "The outermost dimension of updates and indices ",
-        "must match. Got indices.shape ", indices_shape.DebugString(),
-        ", updates.shape ", updates_shape.DebugString());
+        "Dimensions [0,1) of indices[shape=", indices_shape.DebugString(),
+        "] = ", indices.dim_size(0), " must match dimensions [0,1) of updates[",
+        "shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0));
   }
   TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates));
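
For context, a usage sketch of the stateful path these checks guard (assuming the TF 2.x API, where tf.Variable.scatter_nd_update maps to the ResourceScatterNdUpdate op; shapes taken from the tests below):

import tensorflow as tf

v = tf.Variable(tf.zeros([5, 3]))
indices = tf.constant([[0], [2]])  # slice_dim = 1, batch_dim = 1
updates = tf.ones([2, 3])  # [2] matches indices[:1]; [3] matches v.shape[1:]
v.scatter_nd_update(indices, updates)
# updates of shape [2, 4] would instead trip shape_err_suffix(), e.g.
# "Dimensions [1,2) of input[shape=[5,3]] must match dimensions [1,2) of
#  updates[shape=[2,4]]"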

View File

@@ -200,8 +200,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) {
   Status s = RunOpKernel();
   EXPECT_TRUE(absl::StrContains(
       s.ToString(),
-      "The outermost dimension of updates and indices must match. Got "
-      "indices.shape [1,3,1], updates.shape [3,3]"))
+      "Dimensions [0,1) of indices[shape=[1,3,1]] = 1 must match dimensions "
+      "[0,1) of updates[shape=[3,3]] = 3"))
       << s;
 }
@@ -217,7 +217,9 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
       {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
   Status s = RunOpKernel();
   EXPECT_TRUE(absl::StrContains(
-      s.ToString(), "Must have updates.shape = indices.shape[:batch_dim]"))
+      s.ToString(),
+      "Dimensions [1,2) of input[shape=[5,3]] must match dimensions [1,2) of "
+      "updates[shape=[3,4]]"))
       << s;
 }
@@ -233,7 +235,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
   Status s = RunOpKernel();
   EXPECT_TRUE(absl::StrContains(
       s.ToString(),
-      "The outermost dimension of updates and indices must match."))
+      "Dimensions [0,1) of indices[shape=[3,1]] = 3 must match dimensions [0,1)"
+      " of updates[shape=[2,3]] = 2"))
       << s;
 }

View File

@@ -2974,73 +2974,6 @@ REGISTER_OP("QuantizedInstanceNorm")
 namespace {
 
-Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
-                            ShapeHandle updates_shape,
-                            ShapeHandle output_shape) {
-  if (c->Value(c->NumElements(output_shape)) == 0 &&
-      (c->Value(c->NumElements(indices_shape)) > 0 ||
-       c->Value(c->NumElements(updates_shape)) > 0)) {
-    return errors::InvalidArgument(
-        "Indices and updates specified for empty output shape");
-  }
-
-  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
-    const int64 outer_dims = c->Rank(indices_shape) - 1;
-    const DimensionHandle ixdim = c->Dim(indices_shape, -1);
-
-    // We can only do more validation if the last dimension of indices
-    // is a known value.
-    if (c->ValueKnown(ixdim)) {
-      int64 ix = c->Value(ixdim);
-      ShapeHandle unused;
-      ShapeHandle prefix_indices;
-      TF_RETURN_IF_ERROR(
-          c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
-      ShapeHandle prefix_updates;
-      TF_RETURN_IF_ERROR(
-          c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
-
-      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
-      if (!s.ok()) {
-        return errors::InvalidArgument(
-            "The outer ", outer_dims,
-            " dimensions of indices.shape=", c->DebugString(indices_shape),
-            " must match the outer ", outer_dims,
-            " dimensions of updates.shape=", c->DebugString(updates_shape),
-            ": ", s.error_message());
-      }
-
-      ShapeHandle suffix_output;
-      TF_RETURN_IF_ERROR(c->Subshape(output_shape, ix, &suffix_output));
-      ShapeHandle suffix_updates;
-      TF_RETURN_IF_ERROR(
-          c->Subshape(updates_shape, outer_dims, &suffix_updates));
-      s = c->Merge(suffix_output, suffix_updates, &unused);
-      if (!s.ok()) {
-        return errors::InvalidArgument(
-            "The inner ", c->Rank(output_shape) - ix,
-            " dimensions of output.shape=", c->DebugString(output_shape),
-            " must match the inner ", c->Rank(updates_shape) - outer_dims,
-            " dimensions of updates.shape=", c->DebugString(updates_shape),
-            ": ", s.error_message());
-      }
-    }
-  }
-
-  c->set_output(0, output_shape);
-  return Status::OK();
-}
-
-Status ScatterNdShape(InferenceContext* c) {
-  ShapeHandle indices_shape;
-  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
-  ShapeHandle updates_shape;
-  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
-  ShapeHandle output_shape;
-  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
-  return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
-}
-
 Status ScatterNdTensorShape(InferenceContext* c) {
   ShapeHandle output_shape;
   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape));
@@ -3048,7 +2981,8 @@ Status ScatterNdTensorShape(InferenceContext* c) {
   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
   ShapeHandle updates_shape;
   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
-  return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
+  return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape,
+                                               output_shape);
 }
 
 }  // namespace
@@ -3088,7 +3022,16 @@ REGISTER_OP("ScatterNd")
     .Output("output: T")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
-    .SetShapeFn(ScatterNdShape);
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle indices_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
+      ShapeHandle updates_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
+      ShapeHandle output_shape;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
+      return shape_inference::ScatterNdShapeHelper(c, indices_shape,
+                                                   updates_shape, output_shape);
+    });
 
 REGISTER_OP("TensorScatterUpdate")
     .Input("tensor: T")
@@ -3142,7 +3085,7 @@ REGISTER_OP("ScatterNdNonAliasingAdd")
     .Output("output: T")
     .Attr("T: {numbertype, bool}")
     .Attr("Tindices: {int32, int64}")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdTensorShape);
 
 REGISTER_OP("FakeQuantWithMinMaxArgs")
     .Attr("min: float = -6.0")

View File

@@ -27,6 +27,39 @@ limitations under the License.
 namespace tensorflow {
 
+TEST(ArrayOpsTest, TensorScatterUpdate_ShapeFn) {
+  ShapeInferenceTestOp op("TensorScatterUpdate");
+  INFER_OK(op, "[4,3];[8,2];[8]", "in0");
+  INFER_OK(op, "[?,?];[?,2];[?]", "in0");
+  INFER_OK(op, "[?];[?];[?]", "in0");
+  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op,
+              "[];[?,2];[?]");
+  INFER_ERROR("Indices and updates specified for empty input", op,
+              "[0,2,2];[8,2];[8]");
+  INFER_ERROR(
+      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
+      "dimensions [0,1) of updates[shape=[9]] = [9]",
+      op, "[?,?];[8,2];[9]");
+  INFER_ERROR(
+      "Dimensions [2,2) of input[shape=[?,?]] = [] must match "
+      "dimensions [1,2) of updates[shape=[?,1]] = [1]",
+      op, "[?,?];[?,2];[?,1]");
+}
+
+TEST(ArrayOpsTest, ScatterNd_ShapeFn) {
+  ShapeInferenceTestOp op("ScatterNd");
+  INFER_OK(op, "[8,2];[8];[2]", "[?,?]");
+  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[?,2];[?];[]");
+  INFER_ERROR(
+      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
+      "dimensions [0,1) of updates[shape=[9]] = [9]",
+      op, "[8,2];[9];[?]");
+}
+
 TEST(ArrayOpsTest, UnravelIndex_ShapeFn) {
   ShapeInferenceTestOp op("UnravelIndex");

View File

@@ -131,6 +131,22 @@ Status ScatterUpdateShape(InferenceContext* c) {
   return Status::OK();
 }
 
+Status ScatterNdUpdateShape(InferenceContext* c) {
+  ShapeHandle input_shape = c->input(0);
+  if (c->input_handle_shapes_and_types(0) != nullptr) {
+    const auto& shape_and_type = *(c->input_handle_shapes_and_types(0));
+    if (!shape_and_type.empty()) {
+      input_shape = shape_and_type[0].shape;
+    }
+  }
+  ShapeHandle indices_shape;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
+  ShapeHandle updates_shape;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
+  return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape,
+                                               input_shape);
+}
+
 }  // namespace
 
 REGISTER_OP("ScatterUpdate")
@@ -211,7 +227,7 @@ REGISTER_OP("ScatterNdUpdate")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdUpdate")
     .Input("ref: resource")
@@ -220,7 +236,7 @@ REGISTER_OP("ResourceScatterNdUpdate")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdAdd")
     .Input("ref: resource")
@@ -229,7 +245,7 @@ REGISTER_OP("ResourceScatterNdAdd")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdSub")
     .Input("ref: resource")
@@ -238,7 +254,7 @@ REGISTER_OP("ResourceScatterNdSub")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdMin")
     .Input("ref: resource")
@@ -247,7 +263,7 @@ REGISTER_OP("ResourceScatterNdMin")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdMax")
     .Input("ref: resource")
@@ -256,7 +272,7 @@ REGISTER_OP("ResourceScatterNdMax")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ScatterNdAdd")
     .Input("ref: Ref(T)")
@@ -266,7 +282,7 @@ REGISTER_OP("ScatterNdAdd")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ScatterNdSub")
     .Input("ref: Ref(T)")
@@ -276,7 +292,7 @@ REGISTER_OP("ScatterNdSub")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ScatterNdMax")
     .Input("ref: Ref(T)")
@@ -286,7 +302,7 @@ REGISTER_OP("ScatterNdMax")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ScatterNdMin")
     .Input("ref: Ref(T)")
@@ -296,7 +312,7 @@ REGISTER_OP("ScatterNdMin")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("CountUpTo")
     .Input("ref: Ref(T)")

View File

@@ -69,6 +69,28 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) {
   INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[2];[];[2]");
 }
 
+TEST(StateOpsTest, ResourceScatterNdUpdate_ShapeFn) {
+  ShapeInferenceTestOp op("ResourceScatterNdUpdate");
+  TF_ASSERT_OK(NodeDefBuilder("test", "ResourceScatterNdUpdate")
+                   .Input("ref", 0, DT_RESOURCE)
+                   .Input("indices", 0, DT_INT32)
+                   .Input("updates", 1, DT_FLOAT)
+                   .Finalize(&op.node_def));
+
+  std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types;
+  op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types);
+  op.input_resource_handle_shapes_and_types.push_back(nullptr);
+  op.input_resource_handle_shapes_and_types.push_back(nullptr);
+  shapes_and_types.emplace_back("[?,?]", DT_FLOAT);
+
+  INFER_OK(op, "[?];[?,2];[?]", "");
+  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op,
+              "[?];[?,2];[]");
+  INFER_ERROR(
+      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
+      "dimensions [0,1) of updates[shape=[9]] = [9]",
+      op, "[?];[8,2];[9]");
+}
+
 TEST(StateOpsTest, TemporaryVariable_ShapeFn) {
   ShapeInferenceTestOp op("TemporaryVariable");
   TensorShape shape({1, 2, 3});

View File

@@ -331,24 +331,24 @@ class StatefulScatterNdTest(test.TestCase):
       self.evaluate(ref.initializer)
       self.assertAllEqual(expected_result, self.evaluate(scatter_update))
 
-  @test_util.run_deprecated_v1
   def testRank3InvalidShape1(self):
     indices = array_ops.zeros([3, 2, 2], dtypes.int32)
     updates = array_ops.zeros([2, 2, 2], dtypes.int32)
     shape = np.array([2, 2, 2])
     ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
     with self.assertRaisesWithPredicateMatch(
-        ValueError, r"The outer \d+ dimensions of indices\.shape="):
+        (errors.InvalidArgumentError, ValueError),
+        r"Dimensions \[\d,\d\) of indices\[shape="):
       state_ops.scatter_nd_update(ref, indices, updates)
 
-  @test_util.run_deprecated_v1
   def testRank3InvalidShape2(self):
     indices = array_ops.zeros([2, 2, 1], dtypes.int32)
     updates = array_ops.zeros([2, 2], dtypes.int32)
     shape = np.array([2, 2, 2])
     ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
     with self.assertRaisesWithPredicateMatch(
-        ValueError, r"The inner \d+ dimensions of input\.shape="):
+        (errors.InvalidArgumentError, ValueError),
+        r"Dimensions \[\d,\d\) of input\[shape="):
       state_ops.scatter_nd_update(ref, indices, updates)
 
   def testConcurrentUpdates(self):
@@ -511,14 +511,14 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase):
       shape = array_ops.placeholder(dtypes.int32, shape=[None])
       self.scatter_nd(indices, updates, shape)
 
-  @test_util.run_deprecated_v1
   def testEmptyOutputShape1(self):
     indices = array_ops.zeros([2, 2, 2], dtypes.int32)
     updates = array_ops.zeros([2, 2, 2], dtypes.int32)
     shape = constant_op.constant([0, 3, 2], dtypes.int32)
 
     with self.assertRaisesWithPredicateMatch(
-        ValueError, "Indices and updates specified for empty output shape"):
+        (errors.InvalidArgumentError, ValueError),
+        "Indices and updates specified for empty"):
      self.scatter_nd(indices, updates, shape)
 
   def testEmptyOutputShape2(self):
@@ -529,7 +529,7 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase):
     with self.cached_session():
       with self.assertRaisesOpError(
-          "Indices and updates specified for empty output"):
+          "Indices and updates specified for empty (input|output)"):
         self.scatter_nd(indices, updates, shape).eval(
             feed_dict={
                 indices: np.zeros([2, 2, 2], dtype=np.int32),
@@ -545,22 +545,22 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase):
     with self.cached_session():
       self.assertEqual(self.evaluate(scatter).size, 0)
 
-  @test_util.run_deprecated_v1
   def testRank3InvalidShape1(self):
     indices = array_ops.zeros([3, 2, 2], dtypes.int32)
     updates = array_ops.zeros([2, 2, 2], dtypes.int32)
     shape = np.array([2, 2, 2])
     with self.assertRaisesWithPredicateMatch(
-        ValueError, r"The outer \d+ dimensions of indices\.shape="):
+        (errors.InvalidArgumentError, ValueError),
+        r"Dimensions \[\d\,\d\) of indices\[shape="):
       self.scatter_nd(indices, updates, shape)
 
-  @test_util.run_deprecated_v1
   def testRank3InvalidShape2(self):
     indices = array_ops.zeros([2, 2, 1], dtypes.int32)
     updates = array_ops.zeros([2, 2], dtypes.int32)
     shape = np.array([2, 2, 2])
     with self.assertRaisesWithPredicateMatch(
-        ValueError, r"The inner \d+ dimensions of (input|output)\.shape="):
+        (errors.InvalidArgumentError, ValueError),
+        r"Dimensions \[\d\,\d\) of input\[shape="):
       self.scatter_nd(indices, updates, shape)
 
   @parameterized.parameters(set((True, context.executing_eagerly())))
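
These tests now accept either exception type because, with the unified messages, the same mis-shaped call fails with errors.InvalidArgumentError from the kernel when executed eagerly and with ValueError from shape inference at graph-construction time, carrying the same "Dimensions [...)" text in both modes. A hedged standalone sketch (assuming TF 2.x):

import numpy as np
import tensorflow as tf

indices = np.zeros([3, 2, 2], np.int32)
updates = np.zeros([2, 2, 2], np.int32)
try:
  tf.scatter_nd(indices, updates, [2, 2, 2])
except (tf.errors.InvalidArgumentError, ValueError) as e:
  print(e)  # Dimensions [0,2) of indices[shape=[3,2,2]] must match
            # dimensions [0,2) of updates[shape=[2,2,2]]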