diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 36ae36e7b74..8157f4ee01d 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -2257,66 +2257,57 @@ Status GatherNdShape(InferenceContext* c) { return Status::OK(); } -Status ScatterNdUpdateShape(InferenceContext* c) { - ShapeHandle input_shape = c->input(0); - if (c->input_handle_shapes_and_types(0) != nullptr) { - // This is called for tf.scatter_nd_update; input is a Variable handle. - const auto& shape_and_type = *(c->input_handle_shapes_and_types(0)); - if (shape_and_type.size() == 1) { - input_shape = shape_and_type[0].shape; - } - } - ShapeHandle indices_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); - ShapeHandle updates_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); - +Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, + ShapeHandle updates_shape, + ShapeHandle input_shape) { if (c->Value(c->NumElements(input_shape)) == 0 && (c->Value(c->NumElements(indices_shape)) > 0 || c->Value(c->NumElements(updates_shape)) > 0)) { return errors::InvalidArgument( - "Indices and updates specified for empty output shape"); + "Indices and updates specified for empty input"); } if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { - const int64 num_outer_dims = c->Rank(indices_shape) - 1; - const DimensionHandle index_size = c->Dim(indices_shape, -1); + const int64 outer_dims = c->Rank(indices_shape) - 1; + const DimensionHandle ixdim = c->Dim(indices_shape, -1); // We can only do more validation if the last dimension of indices // is a known value. - if (c->ValueKnown(index_size)) { - const int64 ix = c->Value(index_size); + if (c->ValueKnown(ixdim)) { + int64 ix = c->Value(ixdim); ShapeHandle unused; ShapeHandle prefix_indices; TF_RETURN_IF_ERROR( - c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices)); + c->Subshape(indices_shape, 0, outer_dims, &prefix_indices)); ShapeHandle prefix_updates; TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates)); + c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); Status s = c->Merge(prefix_indices, prefix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( - "The outer ", num_outer_dims, - " dimensions of indices.shape=", c->DebugString(indices_shape), - " must match the outer ", num_outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); + "Dimensions [0,", outer_dims, + ") of indices[shape=", c->DebugString(indices_shape), + "] = ", c->DebugString(prefix_indices), + " must match dimensions [0,", outer_dims, + ") of updates[shape=", c->DebugString(updates_shape), + "] = ", c->DebugString(prefix_updates), ": ", s.error_message()); } - ShapeHandle input_suffix; - TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix)); + ShapeHandle suffix_output; + TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output)); ShapeHandle suffix_updates; TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, num_outer_dims, &suffix_updates)); - s = c->Merge(input_suffix, suffix_updates, &unused); + c->Subshape(updates_shape, outer_dims, &suffix_updates)); + s = c->Merge(suffix_output, suffix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( - "The inner ", c->Rank(input_shape) - ix, - " dimensions of input.shape=", c->DebugString(input_shape), - " must match the inner ", c->Rank(updates_shape) - num_outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); + "Dimensions [", ix, ",", c->Rank(input_shape), + ") of input[shape=", c->DebugString(input_shape), + "] = ", c->DebugString(suffix_output), " must match dimensions [", + outer_dims, ",", c->Rank(updates_shape), + ") of updates[shape=", c->DebugString(updates_shape), + "] = ", c->DebugString(suffix_updates), ": ", s.error_message()); } } } diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 218400c2435..f3e02638f54 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -241,8 +241,9 @@ Status ValidateVariableResourceHandle( // Shape function for GatherNd operations. Status GatherNdShape(InferenceContext* c); -// Shape function for ScatterNd update/add/sub/... operations. -Status ScatterNdUpdateShape(InferenceContext* c); +// Helper shape function for ScatterNd.../TensorScatter... operations. +Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, + ShapeHandle updates_shape, ShapeHandle input_shape); // Shape function for ops with an explicit "shape" attribute. Status ExplicitShape(InferenceContext* c); diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 88bf16d974e..942740b9af3 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -100,29 +100,31 @@ class ScatterNdOp : public OpKernel { const int64 outer_dims = indices.shape().dims() - 1; for (int i = 0; i < outer_dims; ++i) { - OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i), - errors::InvalidArgument( - "Outer dimensions of indices and update must match. " - "Indices shape: ", - indices.shape().DebugString(), - ", updates shape:", updates.shape().DebugString())); + OP_REQUIRES( + c, indices.shape().dim_size(i) == updates.shape().dim_size(i), + errors::InvalidArgument( + "Dimensions [0,", outer_dims, + ") of indices[shape=", indices.shape().DebugString(), + "] must match dimensions [0,", outer_dims, + ") of updates[shape=", updates.shape().DebugString(), "]")); } const int64 ix = indices.shape().dim_size(outer_dims); - OP_REQUIRES( - c, updates.shape().dims() - outer_dims == shape.dims() - ix, - errors::InvalidArgument("Inner dimensions of output shape must match " - "inner dimensions of updates shape. Output: ", - shape.DebugString(), - " updates: ", updates.shape().DebugString())); + OP_REQUIRES(c, updates.shape().dims() - outer_dims == shape.dims() - ix, + errors::InvalidArgument( + "Dimensions [", ix, ",", shape.dims(), ") of input[shape=", + shape.DebugString(), "] must match dimensions [", + outer_dims, ",", updates.shape().dims(), + ") of updates[shape=", updates.shape().DebugString(), "]")); + for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) { OP_REQUIRES( c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i), - errors::InvalidArgument( - "The inner ", shape.dims() - ix, - " dimensions of output.shape=", shape.DebugString(), - " must match the inner ", updates.shape().dims() - outer_dims, - " dimensions of updates.shape=", updates.shape().DebugString())); + errors::InvalidArgument("Dimensions [", ix, ",", shape.dims(), + ") of input[shape=", shape.DebugString(), + "] must match dimensions [", outer_dims, ",", + updates.shape().dims(), ") of updates[shape=", + updates.shape().DebugString(), "]")); } OP_REQUIRES(c, shape_input.dims() == 1, errors::InvalidArgument("Shape must be a vector")); @@ -602,30 +604,35 @@ Status ValidateUpdateShape(const TensorShape& params_shape, (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1; const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1; - auto shape_err = [&]() { + auto shape_err_prefix = [&]() { return errors::InvalidArgument( - "Must have updates.shape = indices.shape[:batch_dim] + ", - "params_shape[slice_dim:], got updates.shape: ", - updates.shape().DebugString(), - ", indices.shape: ", indices.shape().DebugString(), - ", params_shape: ", params_shape.DebugString(), - ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim); + "Dimensions [0,", batch_dim, + ") of indices[shape=", indices.shape().DebugString(), + "] must match dimensions [0,", batch_dim, + ") of updates[shape=", updates.shape().DebugString(), "]"); + }; + auto shape_err_suffix = [&]() { + return errors::InvalidArgument( + "Dimensions [", slice_dim, ",", params_shape.dims(), + ") of input[shape=", params_shape.DebugString(), + "] must match dimensions [", slice_dim, ",", updates.dims(), + ") of updates[shape=", updates.shape().DebugString(), "]"); }; - if (updates.dims() < batch_dim) return shape_err(); + if (updates.dims() < batch_dim) return shape_err_prefix(); if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) { - return shape_err(); + return shape_err_suffix(); } if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) { - return shape_err(); + return shape_err_suffix(); } for (int d = 0; d < batch_dim; ++d) { - if (updates.dim_size(d) != indices.dim_size(d)) return shape_err(); + if (updates.dim_size(d) != indices.dim_size(d)) return shape_err_prefix(); } for (int d = 0; d < updates.dims() - batch_dim; ++d) { if (updates.dim_size(d + batch_dim) != params_shape.dim_size(d + slice_dim)) { - return shape_err(); + return shape_err_suffix(); } } return Status::OK(); @@ -654,9 +661,9 @@ Status PrepareAndValidateInputs(const TensorShape& params_shape, if (updates.dim_size(0) != indices.dim_size(0)) { return errors::InvalidArgument( - "The outermost dimension of updates and indices ", - "must match. Got indices.shape ", indices_shape.DebugString(), - ", updates.shape ", updates_shape.DebugString()); + "Dimensions [0,1) of indices[shape=", indices_shape.DebugString(), + "] = ", indices.dim_size(0), " must match dimensions [0,1) of updates[", + "shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0)); } TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates)); diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc index 1461831a1fb..9c31bed784f 100644 --- a/tensorflow/core/kernels/scatter_nd_op_test.cc +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -200,8 +200,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) { Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), - "The outermost dimension of updates and indices must match. Got " - "indices.shape [1,3,1], updates.shape [3,3]")) + "Dimensions [0,1) of indices[shape=[1,3,1]] = 1 must match dimensions " + "[0,1) of updates[shape=[3,3]] = 3")) << s; } @@ -217,7 +217,9 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004}); Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( - s.ToString(), "Must have updates.shape = indices.shape[:batch_dim]")) + s.ToString(), + "Dimensions [1,2) of input[shape=[5,3]] must match dimensions [1,2) of " + "updates[shape=[3,4]]")) << s; } @@ -233,7 +235,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), - "The outermost dimension of updates and indices must match.")) + "Dimensions [0,1) of indices[shape=[3,1]] = 3 must match dimensions [0,1)" + " of updates[shape=[2,3]] = 2")) << s; } diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 11bfb9a3346..b4dfe6187d5 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2974,73 +2974,6 @@ REGISTER_OP("QuantizedInstanceNorm") namespace { -Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, - ShapeHandle updates_shape, - ShapeHandle output_shape) { - if (c->Value(c->NumElements(output_shape)) == 0 && - (c->Value(c->NumElements(indices_shape)) > 0 || - c->Value(c->NumElements(updates_shape)) > 0)) { - return errors::InvalidArgument( - "Indices and updates specified for empty output shape"); - } - - if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { - const int64 outer_dims = c->Rank(indices_shape) - 1; - const DimensionHandle ixdim = c->Dim(indices_shape, -1); - - // We can only do more validation if the last dimension of indices - // is a known value. - if (c->ValueKnown(ixdim)) { - int64 ix = c->Value(ixdim); - ShapeHandle unused; - ShapeHandle prefix_indices; - TF_RETURN_IF_ERROR( - c->Subshape(indices_shape, 0, outer_dims, &prefix_indices)); - ShapeHandle prefix_updates; - TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); - - Status s = c->Merge(prefix_indices, prefix_updates, &unused); - if (!s.ok()) { - return errors::InvalidArgument( - "The outer ", outer_dims, - " dimensions of indices.shape=", c->DebugString(indices_shape), - " must match the outer ", outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); - } - - ShapeHandle suffix_output; - TF_RETURN_IF_ERROR(c->Subshape(output_shape, ix, &suffix_output)); - ShapeHandle suffix_updates; - TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, outer_dims, &suffix_updates)); - s = c->Merge(suffix_output, suffix_updates, &unused); - if (!s.ok()) { - return errors::InvalidArgument( - "The inner ", c->Rank(output_shape) - ix, - " dimensions of output.shape=", c->DebugString(output_shape), - " must match the inner ", c->Rank(updates_shape) - outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); - } - } - } - - c->set_output(0, output_shape); - return Status::OK(); -} - -Status ScatterNdShape(InferenceContext* c) { - ShapeHandle indices_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape)); - ShapeHandle updates_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); - ShapeHandle output_shape; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape)); - return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape); -} - Status ScatterNdTensorShape(InferenceContext* c) { ShapeHandle output_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape)); @@ -3048,7 +2981,8 @@ Status ScatterNdTensorShape(InferenceContext* c) { TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); ShapeHandle updates_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); - return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape); + return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape, + output_shape); } } // namespace @@ -3088,7 +3022,16 @@ REGISTER_OP("ScatterNd") .Output("output: T") .Attr("T: type") .Attr("Tindices: {int32, int64}") - .SetShapeFn(ScatterNdShape); + .SetShapeFn([](InferenceContext* c) { + ShapeHandle indices_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape)); + ShapeHandle updates_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); + ShapeHandle output_shape; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape)); + return shape_inference::ScatterNdShapeHelper(c, indices_shape, + updates_shape, output_shape); + }); REGISTER_OP("TensorScatterUpdate") .Input("tensor: T") @@ -3142,7 +3085,7 @@ REGISTER_OP("ScatterNdNonAliasingAdd") .Output("output: T") .Attr("T: {numbertype, bool}") .Attr("Tindices: {int32, int64}") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdTensorShape); REGISTER_OP("FakeQuantWithMinMaxArgs") .Attr("min: float = -6.0") diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 1725bdbac39..412c926d386 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -27,6 +27,39 @@ limitations under the License. namespace tensorflow { +TEST(ArrayOpsTest, TensorScatterUpdate_ShapeFn) { + ShapeInferenceTestOp op("TensorScatterUpdate"); + + INFER_OK(op, "[4,3];[8,2];[8]", "in0"); + INFER_OK(op, "[?,?];[?,2];[?]", "in0"); + INFER_OK(op, "[?];[?];[?]", "in0"); + + INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, + "[];[?,2];[?]"); + INFER_ERROR("Indices and updates specified for empty input", op, + "[0,2,2];[8,2];[8]"); + INFER_ERROR( + "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match " + "dimensions [0,1) of updates[shape=[9]] = [9]", + op, "[?,?];[8,2];[9]"); + INFER_ERROR( + "Dimensions [2,2) of input[shape=[?,?]] = [] must match " + "dimensions [1,2) of updates[shape=[?,1]] = [1]", + op, "[?,?];[?,2];[?,1]"); +} + +TEST(ArrayOpsTest, ScatterNd_ShapeFn) { + ShapeInferenceTestOp op("ScatterNd"); + + INFER_OK(op, "[8,2];[8];[2]", "[?,?]"); + + INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[?,2];[?];[]"); + INFER_ERROR( + "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match " + "dimensions [0,1) of updates[shape=[9]] = [9]", + op, "[8,2];[9];[?]"); +} + TEST(ArrayOpsTest, UnravelIndex_ShapeFn) { ShapeInferenceTestOp op("UnravelIndex"); diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 500d5ec88b8..5d856396360 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -131,6 +131,22 @@ Status ScatterUpdateShape(InferenceContext* c) { return Status::OK(); } +Status ScatterNdUpdateShape(InferenceContext* c) { + ShapeHandle input_shape = c->input(0); + if (c->input_handle_shapes_and_types(0) != nullptr) { + const auto& shape_and_type = *(c->input_handle_shapes_and_types(0)); + if (!shape_and_type.empty()) { + input_shape = shape_and_type[0].shape; + } + } + ShapeHandle indices_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); + ShapeHandle updates_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); + return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape, + input_shape); +} + } // namespace REGISTER_OP("ScatterUpdate") @@ -211,7 +227,7 @@ REGISTER_OP("ScatterNdUpdate") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdUpdate") .Input("ref: resource") @@ -220,7 +236,7 @@ REGISTER_OP("ResourceScatterNdUpdate") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdAdd") .Input("ref: resource") @@ -229,7 +245,7 @@ REGISTER_OP("ResourceScatterNdAdd") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdSub") .Input("ref: resource") @@ -238,7 +254,7 @@ REGISTER_OP("ResourceScatterNdSub") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdMin") .Input("ref: resource") @@ -247,7 +263,7 @@ REGISTER_OP("ResourceScatterNdMin") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdMax") .Input("ref: resource") @@ -256,7 +272,7 @@ REGISTER_OP("ResourceScatterNdMax") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ScatterNdAdd") .Input("ref: Ref(T)") @@ -266,7 +282,7 @@ REGISTER_OP("ScatterNdAdd") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ScatterNdSub") .Input("ref: Ref(T)") @@ -276,7 +292,7 @@ REGISTER_OP("ScatterNdSub") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ScatterNdMax") .Input("ref: Ref(T)") @@ -286,7 +302,7 @@ REGISTER_OP("ScatterNdMax") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ScatterNdMin") .Input("ref: Ref(T)") @@ -296,7 +312,7 @@ REGISTER_OP("ScatterNdMin") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("CountUpTo") .Input("ref: Ref(T)") diff --git a/tensorflow/core/ops/state_ops_test.cc b/tensorflow/core/ops/state_ops_test.cc index a0caad4a49f..bc68cf46f03 100644 --- a/tensorflow/core/ops/state_ops_test.cc +++ b/tensorflow/core/ops/state_ops_test.cc @@ -69,6 +69,28 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) { INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[2];[];[2]"); } +TEST(StateOpsTest, ResourceScatterNdUpdate_ShapeFn) { + ShapeInferenceTestOp op("ResourceScatterNdUpdate"); + TF_ASSERT_OK(NodeDefBuilder("test", "ResourceScatterNdUpdate") + .Input("ref", 0, DT_RESOURCE) + .Input("indices", 0, DT_INT32) + .Input("updates", 1, DT_FLOAT) + .Finalize(&op.node_def)); + + std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types; + op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types); + op.input_resource_handle_shapes_and_types.push_back(nullptr); + op.input_resource_handle_shapes_and_types.push_back(nullptr); + shapes_and_types.emplace_back("[?,?]", DT_FLOAT); + INFER_OK(op, "[?];[?,2];[?]", ""); + INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, + "[?];[?,2];[]"); + INFER_ERROR( + "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match " + "dimensions [0,1) of updates[shape=[9]] = [9]", + op, "[?];[8,2];[9]"); +} + TEST(StateOpsTest, TemporaryVariable_ShapeFn) { ShapeInferenceTestOp op("TemporaryVariable"); TensorShape shape({1, 2, 3}); diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index c5e5e549ee7..d5843c1a766 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -331,24 +331,24 @@ class StatefulScatterNdTest(test.TestCase): self.evaluate(ref.initializer) self.assertAllEqual(expected_result, self.evaluate(scatter_update)) - @test_util.run_deprecated_v1 def testRank3InvalidShape1(self): indices = array_ops.zeros([3, 2, 2], dtypes.int32) updates = array_ops.zeros([2, 2, 2], dtypes.int32) shape = np.array([2, 2, 2]) ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) with self.assertRaisesWithPredicateMatch( - ValueError, r"The outer \d+ dimensions of indices\.shape="): + (errors.InvalidArgumentError, ValueError), + r"Dimensions \[\d,\d\) of indices\[shape="): state_ops.scatter_nd_update(ref, indices, updates) - @test_util.run_deprecated_v1 def testRank3InvalidShape2(self): indices = array_ops.zeros([2, 2, 1], dtypes.int32) updates = array_ops.zeros([2, 2], dtypes.int32) shape = np.array([2, 2, 2]) ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) with self.assertRaisesWithPredicateMatch( - ValueError, r"The inner \d+ dimensions of input\.shape="): + (errors.InvalidArgumentError, ValueError), + r"Dimensions \[\d,\d\) of input\[shape="): state_ops.scatter_nd_update(ref, indices, updates) def testConcurrentUpdates(self): @@ -511,14 +511,14 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase): shape = array_ops.placeholder(dtypes.int32, shape=[None]) self.scatter_nd(indices, updates, shape) - @test_util.run_deprecated_v1 def testEmptyOutputShape1(self): indices = array_ops.zeros([2, 2, 2], dtypes.int32) updates = array_ops.zeros([2, 2, 2], dtypes.int32) shape = constant_op.constant([0, 3, 2], dtypes.int32) with self.assertRaisesWithPredicateMatch( - ValueError, "Indices and updates specified for empty output shape"): + (errors.InvalidArgumentError, ValueError), + "Indices and updates specified for empty"): self.scatter_nd(indices, updates, shape) def testEmptyOutputShape2(self): @@ -529,7 +529,7 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase): with self.cached_session(): with self.assertRaisesOpError( - "Indices and updates specified for empty output"): + "Indices and updates specified for empty (input|output)"): self.scatter_nd(indices, updates, shape).eval( feed_dict={ indices: np.zeros([2, 2, 2], dtype=np.int32), @@ -545,22 +545,22 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase): with self.cached_session(): self.assertEqual(self.evaluate(scatter).size, 0) - @test_util.run_deprecated_v1 def testRank3InvalidShape1(self): indices = array_ops.zeros([3, 2, 2], dtypes.int32) updates = array_ops.zeros([2, 2, 2], dtypes.int32) shape = np.array([2, 2, 2]) with self.assertRaisesWithPredicateMatch( - ValueError, r"The outer \d+ dimensions of indices\.shape="): + (errors.InvalidArgumentError, ValueError), + r"Dimensions \[\d\,\d\) of indices\[shape="): self.scatter_nd(indices, updates, shape) - @test_util.run_deprecated_v1 def testRank3InvalidShape2(self): indices = array_ops.zeros([2, 2, 1], dtypes.int32) updates = array_ops.zeros([2, 2], dtypes.int32) shape = np.array([2, 2, 2]) with self.assertRaisesWithPredicateMatch( - ValueError, r"The inner \d+ dimensions of (input|output)\.shape="): + (errors.InvalidArgumentError, ValueError), + r"Dimensions \[\d\,\d\) of input\[shape="): self.scatter_nd(indices, updates, shape) @parameterized.parameters(set((True, context.executing_eagerly())))