Check that rank(updates) = rank(indices + params[1:]) in resource_scatter_update, to match behavior of V1 scatter_update and shape function.
PiperOrigin-RevId: 316032217 Change-Id: I744299b73815a457c23be9dccee867b20116055b
This commit is contained in:
parent
ffc7592c82
commit
b62ad45a87
@ -887,6 +887,17 @@ class ResourceScatterUpdateOp : public OpKernel {
|
||||
const Tensor& indices = c->input(1);
|
||||
const Tensor& updates = c->input(2);
|
||||
|
||||
// Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
|
||||
OP_REQUIRES(c,
|
||||
updates.dims() == 0 ||
|
||||
updates.dims() == indices.dims() + params->dims() - 1,
|
||||
errors::InvalidArgument(
|
||||
"Must have updates.shape = indices.shape + "
|
||||
"params.shape[1:] or updates.shape = [], got ",
|
||||
"updates.shape ", updates.shape().DebugString(),
|
||||
", indices.shape ", indices.shape().DebugString(),
|
||||
", params.shape ", params->shape().DebugString()));
|
||||
|
||||
// Check that we have enough index space
|
||||
const int64 N_big = indices.NumElements();
|
||||
OP_REQUIRES(
|
||||
|
@ -58,6 +58,15 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) {
|
||||
|
||||
// Resolve shape on first updates dimension.
|
||||
INFER_OK(op, "[1,2];[3];[?,2]", "in0");
|
||||
|
||||
// Allow the update to be a scalar.
|
||||
INFER_OK(op, "[1,2];[3];?", "in0");
|
||||
|
||||
// Allow a scalar index.
|
||||
INFER_OK(op, "[1,2];[];[2]", "in0");
|
||||
|
||||
// Check the requirement updates.shape = indices.shape + ref.shape[1:].
|
||||
INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[2];[];[2]");
|
||||
}
|
||||
|
||||
TEST(StateOpsTest, TemporaryVariable_ShapeFn) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user