From b62ad45a875e5e53859261e70d74acac36df4a4a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 11 Jun 2020 20:23:33 -0700 Subject: [PATCH] Check that rank(updates) = rank(indices + params[1:]) in resource_scatter_update, to match behavior of V1 scatter_update and shape function. PiperOrigin-RevId: 316032217 Change-Id: I744299b73815a457c23be9dccee867b20116055b --- tensorflow/core/kernels/resource_variable_ops.cc | 11 +++++++++++ tensorflow/core/ops/state_ops_test.cc | 9 +++++++++ 2 files changed, 20 insertions(+) diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index b606d411a3d..0fc1d53749f 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -887,6 +887,17 @@ class ResourceScatterUpdateOp : public OpKernel { const Tensor& indices = c->input(1); const Tensor& updates = c->input(2); + // Check that rank(updates.shape) = rank(indices.shape + params.shape[1:]) + OP_REQUIRES(c, + updates.dims() == 0 || + updates.dims() == indices.dims() + params->dims() - 1, + errors::InvalidArgument( + "Must have updates.shape = indices.shape + " + "params.shape[1:] or updates.shape = [], got ", + "updates.shape ", updates.shape().DebugString(), + ", indices.shape ", indices.shape().DebugString(), + ", params.shape ", params->shape().DebugString())); + // Check that we have enough index space const int64 N_big = indices.NumElements(); OP_REQUIRES( diff --git a/tensorflow/core/ops/state_ops_test.cc b/tensorflow/core/ops/state_ops_test.cc index 6d05dd0b96c..a0caad4a49f 100644 --- a/tensorflow/core/ops/state_ops_test.cc +++ b/tensorflow/core/ops/state_ops_test.cc @@ -58,6 +58,15 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) { // Resolve shape on first updates dimension. INFER_OK(op, "[1,2];[3];[?,2]", "in0"); + + // Allow the update to be a scalar. + INFER_OK(op, "[1,2];[3];?", "in0"); + + // Allow a scalar index. + INFER_OK(op, "[1,2];[];[2]", "in0"); + + // Check the requirement updates.shape = indices.shape + ref.shape[1:]. + INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[2];[];[2]"); } TEST(StateOpsTest, TemporaryVariable_ShapeFn) {