Check that rank(updates) = rank(indices + params[1:]) in resource_scatter_update, to match behavior of V1 scatter_update and shape function.

PiperOrigin-RevId: 316032217
Change-Id: I744299b73815a457c23be9dccee867b20116055b
This commit is contained in:
A. Unique TensorFlower 2020-06-11 20:23:33 -07:00 committed by TensorFlower Gardener
parent ffc7592c82
commit b62ad45a87
2 changed files with 20 additions and 0 deletions

View File

@ -887,6 +887,17 @@ class ResourceScatterUpdateOp : public OpKernel {
const Tensor& indices = c->input(1);
const Tensor& updates = c->input(2);
// Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
OP_REQUIRES(c,
updates.dims() == 0 ||
updates.dims() == indices.dims() + params->dims() - 1,
errors::InvalidArgument(
"Must have updates.shape = indices.shape + "
"params.shape[1:] or updates.shape = [], got ",
"updates.shape ", updates.shape().DebugString(),
", indices.shape ", indices.shape().DebugString(),
", params.shape ", params->shape().DebugString()));
// Check that we have enough index space
const int64 N_big = indices.NumElements();
OP_REQUIRES(

View File

@ -58,6 +58,15 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) {
// Resolve shape on first updates dimension.
INFER_OK(op, "[1,2];[3];[?,2]", "in0");
// Allow the update to be a scalar.
INFER_OK(op, "[1,2];[3];?", "in0");
// Allow a scalar index.
INFER_OK(op, "[1,2];[];[2]", "in0");
// Check the requirement updates.shape = indices.shape + ref.shape[1:].
INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[2];[];[2]");
}
TEST(StateOpsTest, TemporaryVariable_ShapeFn) {