From a01cf466aac96a2745e26e8626a45468c8d9516f Mon Sep 17 00:00:00 2001 From: Rohan Jain <rohanj@google.com> Date: Wed, 12 Aug 2020 21:42:55 -0700 Subject: [PATCH] Unifying the scatter_nd* type ops shape inference code. It was duplicated in two places. Also cleaned up the error messages a bit to remove references to inner and outer dimensions of tensors and directly reference which dimensions the error message is referring to. Also unifying eager and graph mode error messages and removing some run_deprecated_v1 annotations as a result Added some c++ shape inference tests as well. PiperOrigin-RevId: 326378038 Change-Id: I58ab87ef0c476049da79c5896838a3b92649b772 --- tensorflow/core/framework/common_shape_fns.cc | 61 ++++++-------- tensorflow/core/framework/common_shape_fns.h | 5 +- tensorflow/core/kernels/scatter_nd_op.cc | 71 +++++++++------- tensorflow/core/kernels/scatter_nd_op_test.cc | 11 ++- tensorflow/core/ops/array_ops.cc | 83 +++---------------- tensorflow/core/ops/array_ops_test.cc | 33 ++++++++ tensorflow/core/ops/state_ops.cc | 36 +++++--- tensorflow/core/ops/state_ops_test.cc | 22 +++++ .../kernel_tests/scatter_nd_ops_test.py | 22 ++--- 9 files changed, 180 insertions(+), 164 deletions(-) diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 36ae36e7b74..8157f4ee01d 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -2257,66 +2257,57 @@ Status GatherNdShape(InferenceContext* c) { return Status::OK(); } -Status ScatterNdUpdateShape(InferenceContext* c) { - ShapeHandle input_shape = c->input(0); - if (c->input_handle_shapes_and_types(0) != nullptr) { - // This is called for tf.scatter_nd_update; input is a Variable handle. - const auto& shape_and_type = *(c->input_handle_shapes_and_types(0)); - if (shape_and_type.size() == 1) { - input_shape = shape_and_type[0].shape; - } - } - ShapeHandle indices_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); - ShapeHandle updates_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); - +Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, + ShapeHandle updates_shape, + ShapeHandle input_shape) { if (c->Value(c->NumElements(input_shape)) == 0 && (c->Value(c->NumElements(indices_shape)) > 0 || c->Value(c->NumElements(updates_shape)) > 0)) { return errors::InvalidArgument( - "Indices and updates specified for empty output shape"); + "Indices and updates specified for empty input"); } if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { - const int64 num_outer_dims = c->Rank(indices_shape) - 1; - const DimensionHandle index_size = c->Dim(indices_shape, -1); + const int64 outer_dims = c->Rank(indices_shape) - 1; + const DimensionHandle ixdim = c->Dim(indices_shape, -1); // We can only do more validation if the last dimension of indices // is a known value. - if (c->ValueKnown(index_size)) { - const int64 ix = c->Value(index_size); + if (c->ValueKnown(ixdim)) { + int64 ix = c->Value(ixdim); ShapeHandle unused; ShapeHandle prefix_indices; TF_RETURN_IF_ERROR( - c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices)); + c->Subshape(indices_shape, 0, outer_dims, &prefix_indices)); ShapeHandle prefix_updates; TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates)); + c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); Status s = c->Merge(prefix_indices, prefix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( - "The outer ", num_outer_dims, - " dimensions of indices.shape=", c->DebugString(indices_shape), - " must match the outer ", num_outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); + "Dimensions [0,", outer_dims, + ") of indices[shape=", c->DebugString(indices_shape), + "] = ", c->DebugString(prefix_indices), + " must match dimensions [0,", outer_dims, + ") of updates[shape=", c->DebugString(updates_shape), + "] = ", c->DebugString(prefix_updates), ": ", s.error_message()); } - ShapeHandle input_suffix; - TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix)); + ShapeHandle suffix_output; + TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output)); ShapeHandle suffix_updates; TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, num_outer_dims, &suffix_updates)); - s = c->Merge(input_suffix, suffix_updates, &unused); + c->Subshape(updates_shape, outer_dims, &suffix_updates)); + s = c->Merge(suffix_output, suffix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( - "The inner ", c->Rank(input_shape) - ix, - " dimensions of input.shape=", c->DebugString(input_shape), - " must match the inner ", c->Rank(updates_shape) - num_outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); + "Dimensions [", ix, ",", c->Rank(input_shape), + ") of input[shape=", c->DebugString(input_shape), + "] = ", c->DebugString(suffix_output), " must match dimensions [", + outer_dims, ",", c->Rank(updates_shape), + ") of updates[shape=", c->DebugString(updates_shape), + "] = ", c->DebugString(suffix_updates), ": ", s.error_message()); } } } diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 218400c2435..f3e02638f54 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -241,8 +241,9 @@ Status ValidateVariableResourceHandle( // Shape function for GatherNd operations. Status GatherNdShape(InferenceContext* c); -// Shape function for ScatterNd update/add/sub/... operations. -Status ScatterNdUpdateShape(InferenceContext* c); +// Helper shape function for ScatterNd.../TensorScatter... operations. +Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, + ShapeHandle updates_shape, ShapeHandle input_shape); // Shape function for ops with an explicit "shape" attribute. Status ExplicitShape(InferenceContext* c); diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 88bf16d974e..942740b9af3 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -100,29 +100,31 @@ class ScatterNdOp : public OpKernel { const int64 outer_dims = indices.shape().dims() - 1; for (int i = 0; i < outer_dims; ++i) { - OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i), - errors::InvalidArgument( - "Outer dimensions of indices and update must match. " - "Indices shape: ", - indices.shape().DebugString(), - ", updates shape:", updates.shape().DebugString())); + OP_REQUIRES( + c, indices.shape().dim_size(i) == updates.shape().dim_size(i), + errors::InvalidArgument( + "Dimensions [0,", outer_dims, + ") of indices[shape=", indices.shape().DebugString(), + "] must match dimensions [0,", outer_dims, + ") of updates[shape=", updates.shape().DebugString(), "]")); } const int64 ix = indices.shape().dim_size(outer_dims); - OP_REQUIRES( - c, updates.shape().dims() - outer_dims == shape.dims() - ix, - errors::InvalidArgument("Inner dimensions of output shape must match " - "inner dimensions of updates shape. Output: ", - shape.DebugString(), - " updates: ", updates.shape().DebugString())); + OP_REQUIRES(c, updates.shape().dims() - outer_dims == shape.dims() - ix, + errors::InvalidArgument( + "Dimensions [", ix, ",", shape.dims(), ") of input[shape=", + shape.DebugString(), "] must match dimensions [", + outer_dims, ",", updates.shape().dims(), + ") of updates[shape=", updates.shape().DebugString(), "]")); + for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) { OP_REQUIRES( c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i), - errors::InvalidArgument( - "The inner ", shape.dims() - ix, - " dimensions of output.shape=", shape.DebugString(), - " must match the inner ", updates.shape().dims() - outer_dims, - " dimensions of updates.shape=", updates.shape().DebugString())); + errors::InvalidArgument("Dimensions [", ix, ",", shape.dims(), + ") of input[shape=", shape.DebugString(), + "] must match dimensions [", outer_dims, ",", + updates.shape().dims(), ") of updates[shape=", + updates.shape().DebugString(), "]")); } OP_REQUIRES(c, shape_input.dims() == 1, errors::InvalidArgument("Shape must be a vector")); @@ -602,30 +604,35 @@ Status ValidateUpdateShape(const TensorShape& params_shape, (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1; const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1; - auto shape_err = [&]() { + auto shape_err_prefix = [&]() { return errors::InvalidArgument( - "Must have updates.shape = indices.shape[:batch_dim] + ", - "params_shape[slice_dim:], got updates.shape: ", - updates.shape().DebugString(), - ", indices.shape: ", indices.shape().DebugString(), - ", params_shape: ", params_shape.DebugString(), - ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim); + "Dimensions [0,", batch_dim, + ") of indices[shape=", indices.shape().DebugString(), + "] must match dimensions [0,", batch_dim, + ") of updates[shape=", updates.shape().DebugString(), "]"); + }; + auto shape_err_suffix = [&]() { + return errors::InvalidArgument( + "Dimensions [", slice_dim, ",", params_shape.dims(), + ") of input[shape=", params_shape.DebugString(), + "] must match dimensions [", slice_dim, ",", updates.dims(), + ") of updates[shape=", updates.shape().DebugString(), "]"); }; - if (updates.dims() < batch_dim) return shape_err(); + if (updates.dims() < batch_dim) return shape_err_prefix(); if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) { - return shape_err(); + return shape_err_suffix(); } if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) { - return shape_err(); + return shape_err_suffix(); } for (int d = 0; d < batch_dim; ++d) { - if (updates.dim_size(d) != indices.dim_size(d)) return shape_err(); + if (updates.dim_size(d) != indices.dim_size(d)) return shape_err_prefix(); } for (int d = 0; d < updates.dims() - batch_dim; ++d) { if (updates.dim_size(d + batch_dim) != params_shape.dim_size(d + slice_dim)) { - return shape_err(); + return shape_err_suffix(); } } return Status::OK(); @@ -654,9 +661,9 @@ Status PrepareAndValidateInputs(const TensorShape& params_shape, if (updates.dim_size(0) != indices.dim_size(0)) { return errors::InvalidArgument( - "The outermost dimension of updates and indices ", - "must match. Got indices.shape ", indices_shape.DebugString(), - ", updates.shape ", updates_shape.DebugString()); + "Dimensions [0,1) of indices[shape=", indices_shape.DebugString(), + "] = ", indices.dim_size(0), " must match dimensions [0,1) of updates[", + "shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0)); } TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates)); diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc index 1461831a1fb..9c31bed784f 100644 --- a/tensorflow/core/kernels/scatter_nd_op_test.cc +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -200,8 +200,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) { Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), - "The outermost dimension of updates and indices must match. Got " - "indices.shape [1,3,1], updates.shape [3,3]")) + "Dimensions [0,1) of indices[shape=[1,3,1]] = 1 must match dimensions " + "[0,1) of updates[shape=[3,3]] = 3")) << s; } @@ -217,7 +217,9 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004}); Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( - s.ToString(), "Must have updates.shape = indices.shape[:batch_dim]")) + s.ToString(), + "Dimensions [1,2) of input[shape=[5,3]] must match dimensions [1,2) of " + "updates[shape=[3,4]]")) << s; } @@ -233,7 +235,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), - "The outermost dimension of updates and indices must match.")) + "Dimensions [0,1) of indices[shape=[3,1]] = 3 must match dimensions [0,1)" + " of updates[shape=[2,3]] = 2")) << s; } diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 11bfb9a3346..b4dfe6187d5 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2974,73 +2974,6 @@ REGISTER_OP("QuantizedInstanceNorm") namespace { -Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, - ShapeHandle updates_shape, - ShapeHandle output_shape) { - if (c->Value(c->NumElements(output_shape)) == 0 && - (c->Value(c->NumElements(indices_shape)) > 0 || - c->Value(c->NumElements(updates_shape)) > 0)) { - return errors::InvalidArgument( - "Indices and updates specified for empty output shape"); - } - - if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { - const int64 outer_dims = c->Rank(indices_shape) - 1; - const DimensionHandle ixdim = c->Dim(indices_shape, -1); - - // We can only do more validation if the last dimension of indices - // is a known value. - if (c->ValueKnown(ixdim)) { - int64 ix = c->Value(ixdim); - ShapeHandle unused; - ShapeHandle prefix_indices; - TF_RETURN_IF_ERROR( - c->Subshape(indices_shape, 0, outer_dims, &prefix_indices)); - ShapeHandle prefix_updates; - TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); - - Status s = c->Merge(prefix_indices, prefix_updates, &unused); - if (!s.ok()) { - return errors::InvalidArgument( - "The outer ", outer_dims, - " dimensions of indices.shape=", c->DebugString(indices_shape), - " must match the outer ", outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); - } - - ShapeHandle suffix_output; - TF_RETURN_IF_ERROR(c->Subshape(output_shape, ix, &suffix_output)); - ShapeHandle suffix_updates; - TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, outer_dims, &suffix_updates)); - s = c->Merge(suffix_output, suffix_updates, &unused); - if (!s.ok()) { - return errors::InvalidArgument( - "The inner ", c->Rank(output_shape) - ix, - " dimensions of output.shape=", c->DebugString(output_shape), - " must match the inner ", c->Rank(updates_shape) - outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); - } - } - } - - c->set_output(0, output_shape); - return Status::OK(); -} - -Status ScatterNdShape(InferenceContext* c) { - ShapeHandle indices_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape)); - ShapeHandle updates_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); - ShapeHandle output_shape; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape)); - return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape); -} - Status ScatterNdTensorShape(InferenceContext* c) { ShapeHandle output_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape)); @@ -3048,7 +2981,8 @@ Status ScatterNdTensorShape(InferenceContext* c) { TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); ShapeHandle updates_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); - return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape); + return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape, + output_shape); } } // namespace @@ -3088,7 +3022,16 @@ REGISTER_OP("ScatterNd") .Output("output: T") .Attr("T: type") .Attr("Tindices: {int32, int64}") - .SetShapeFn(ScatterNdShape); + .SetShapeFn([](InferenceContext* c) { + ShapeHandle indices_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape)); + ShapeHandle updates_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); + ShapeHandle output_shape; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape)); + return shape_inference::ScatterNdShapeHelper(c, indices_shape, + updates_shape, output_shape); + }); REGISTER_OP("TensorScatterUpdate") .Input("tensor: T") @@ -3142,7 +3085,7 @@ REGISTER_OP("ScatterNdNonAliasingAdd") .Output("output: T") .Attr("T: {numbertype, bool}") .Attr("Tindices: {int32, int64}") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdTensorShape); REGISTER_OP("FakeQuantWithMinMaxArgs") .Attr("min: float = -6.0") diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 1725bdbac39..412c926d386 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -27,6 +27,39 @@ limitations under the License. namespace tensorflow { +TEST(ArrayOpsTest, TensorScatterUpdate_ShapeFn) { + ShapeInferenceTestOp op("TensorScatterUpdate"); + + INFER_OK(op, "[4,3];[8,2];[8]", "in0"); + INFER_OK(op, "[?,?];[?,2];[?]", "in0"); + INFER_OK(op, "[?];[?];[?]", "in0"); + + INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, + "[];[?,2];[?]"); + INFER_ERROR("Indices and updates specified for empty input", op, + "[0,2,2];[8,2];[8]"); + INFER_ERROR( + "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match " + "dimensions [0,1) of updates[shape=[9]] = [9]", + op, "[?,?];[8,2];[9]"); + INFER_ERROR( + "Dimensions [2,2) of input[shape=[?,?]] = [] must match " + "dimensions [1,2) of updates[shape=[?,1]] = [1]", + op, "[?,?];[?,2];[?,1]"); +} + +TEST(ArrayOpsTest, ScatterNd_ShapeFn) { + ShapeInferenceTestOp op("ScatterNd"); + + INFER_OK(op, "[8,2];[8];[2]", "[?,?]"); + + INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[?,2];[?];[]"); + INFER_ERROR( + "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match " + "dimensions [0,1) of updates[shape=[9]] = [9]", + op, "[8,2];[9];[?]"); +} + TEST(ArrayOpsTest, UnravelIndex_ShapeFn) { ShapeInferenceTestOp op("UnravelIndex"); diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 500d5ec88b8..5d856396360 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -131,6 +131,22 @@ Status ScatterUpdateShape(InferenceContext* c) { return Status::OK(); } +Status ScatterNdUpdateShape(InferenceContext* c) { + ShapeHandle input_shape = c->input(0); + if (c->input_handle_shapes_and_types(0) != nullptr) { + const auto& shape_and_type = *(c->input_handle_shapes_and_types(0)); + if (!shape_and_type.empty()) { + input_shape = shape_and_type[0].shape; + } + } + ShapeHandle indices_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); + ShapeHandle updates_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); + return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape, + input_shape); +} + } // namespace REGISTER_OP("ScatterUpdate") @@ -211,7 +227,7 @@ REGISTER_OP("ScatterNdUpdate") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdUpdate") .Input("ref: resource") @@ -220,7 +236,7 @@ REGISTER_OP("ResourceScatterNdUpdate") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdAdd") .Input("ref: resource") @@ -229,7 +245,7 @@ REGISTER_OP("ResourceScatterNdAdd") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdSub") .Input("ref: resource") @@ -238,7 +254,7 @@ REGISTER_OP("ResourceScatterNdSub") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdMin") .Input("ref: resource") @@ -247,7 +263,7 @@ REGISTER_OP("ResourceScatterNdMin") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ResourceScatterNdMax") .Input("ref: resource") @@ -256,7 +272,7 @@ REGISTER_OP("ResourceScatterNdMax") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ScatterNdAdd") .Input("ref: Ref(T)") @@ -266,7 +282,7 @@ REGISTER_OP("ScatterNdAdd") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ScatterNdSub") .Input("ref: Ref(T)") @@ -276,7 +292,7 @@ REGISTER_OP("ScatterNdSub") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ScatterNdMax") .Input("ref: Ref(T)") @@ -286,7 +302,7 @@ REGISTER_OP("ScatterNdMax") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("ScatterNdMin") .Input("ref: Ref(T)") @@ -296,7 +312,7 @@ REGISTER_OP("ScatterNdMin") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(shape_inference::ScatterNdUpdateShape); + .SetShapeFn(ScatterNdUpdateShape); REGISTER_OP("CountUpTo") .Input("ref: Ref(T)") diff --git a/tensorflow/core/ops/state_ops_test.cc b/tensorflow/core/ops/state_ops_test.cc index a0caad4a49f..bc68cf46f03 100644 --- a/tensorflow/core/ops/state_ops_test.cc +++ b/tensorflow/core/ops/state_ops_test.cc @@ -69,6 +69,28 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) { INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[2];[];[2]"); } +TEST(StateOpsTest, ResourceScatterNdUpdate_ShapeFn) { + ShapeInferenceTestOp op("ResourceScatterNdUpdate"); + TF_ASSERT_OK(NodeDefBuilder("test", "ResourceScatterNdUpdate") + .Input("ref", 0, DT_RESOURCE) + .Input("indices", 0, DT_INT32) + .Input("updates", 1, DT_FLOAT) + .Finalize(&op.node_def)); + + std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types; + op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types); + op.input_resource_handle_shapes_and_types.push_back(nullptr); + op.input_resource_handle_shapes_and_types.push_back(nullptr); + shapes_and_types.emplace_back("[?,?]", DT_FLOAT); + INFER_OK(op, "[?];[?,2];[?]", ""); + INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, + "[?];[?,2];[]"); + INFER_ERROR( + "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match " + "dimensions [0,1) of updates[shape=[9]] = [9]", + op, "[?];[8,2];[9]"); +} + TEST(StateOpsTest, TemporaryVariable_ShapeFn) { ShapeInferenceTestOp op("TemporaryVariable"); TensorShape shape({1, 2, 3}); diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index c5e5e549ee7..d5843c1a766 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -331,24 +331,24 @@ class StatefulScatterNdTest(test.TestCase): self.evaluate(ref.initializer) self.assertAllEqual(expected_result, self.evaluate(scatter_update)) - @test_util.run_deprecated_v1 def testRank3InvalidShape1(self): indices = array_ops.zeros([3, 2, 2], dtypes.int32) updates = array_ops.zeros([2, 2, 2], dtypes.int32) shape = np.array([2, 2, 2]) ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) with self.assertRaisesWithPredicateMatch( - ValueError, r"The outer \d+ dimensions of indices\.shape="): + (errors.InvalidArgumentError, ValueError), + r"Dimensions \[\d,\d\) of indices\[shape="): state_ops.scatter_nd_update(ref, indices, updates) - @test_util.run_deprecated_v1 def testRank3InvalidShape2(self): indices = array_ops.zeros([2, 2, 1], dtypes.int32) updates = array_ops.zeros([2, 2], dtypes.int32) shape = np.array([2, 2, 2]) ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) with self.assertRaisesWithPredicateMatch( - ValueError, r"The inner \d+ dimensions of input\.shape="): + (errors.InvalidArgumentError, ValueError), + r"Dimensions \[\d,\d\) of input\[shape="): state_ops.scatter_nd_update(ref, indices, updates) def testConcurrentUpdates(self): @@ -511,14 +511,14 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase): shape = array_ops.placeholder(dtypes.int32, shape=[None]) self.scatter_nd(indices, updates, shape) - @test_util.run_deprecated_v1 def testEmptyOutputShape1(self): indices = array_ops.zeros([2, 2, 2], dtypes.int32) updates = array_ops.zeros([2, 2, 2], dtypes.int32) shape = constant_op.constant([0, 3, 2], dtypes.int32) with self.assertRaisesWithPredicateMatch( - ValueError, "Indices and updates specified for empty output shape"): + (errors.InvalidArgumentError, ValueError), + "Indices and updates specified for empty"): self.scatter_nd(indices, updates, shape) def testEmptyOutputShape2(self): @@ -529,7 +529,7 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase): with self.cached_session(): with self.assertRaisesOpError( - "Indices and updates specified for empty output"): + "Indices and updates specified for empty (input|output)"): self.scatter_nd(indices, updates, shape).eval( feed_dict={ indices: np.zeros([2, 2, 2], dtype=np.int32), @@ -545,22 +545,22 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase): with self.cached_session(): self.assertEqual(self.evaluate(scatter).size, 0) - @test_util.run_deprecated_v1 def testRank3InvalidShape1(self): indices = array_ops.zeros([3, 2, 2], dtypes.int32) updates = array_ops.zeros([2, 2, 2], dtypes.int32) shape = np.array([2, 2, 2]) with self.assertRaisesWithPredicateMatch( - ValueError, r"The outer \d+ dimensions of indices\.shape="): + (errors.InvalidArgumentError, ValueError), + r"Dimensions \[\d\,\d\) of indices\[shape="): self.scatter_nd(indices, updates, shape) - @test_util.run_deprecated_v1 def testRank3InvalidShape2(self): indices = array_ops.zeros([2, 2, 1], dtypes.int32) updates = array_ops.zeros([2, 2], dtypes.int32) shape = np.array([2, 2, 2]) with self.assertRaisesWithPredicateMatch( - ValueError, r"The inner \d+ dimensions of (input|output)\.shape="): + (errors.InvalidArgumentError, ValueError), + r"Dimensions \[\d\,\d\) of input\[shape="): self.scatter_nd(indices, updates, shape) @parameterized.parameters(set((True, context.executing_eagerly())))