Enable setting shape of _Arg tensor for shape of resource args

Post cl/240699970 the type of the argument would be inferred as
  tensor<Y x !tf.resource<tensor<Y x dtype>>>
instead of
  tensor<* x !tf.resource<tensor<Y x dtype>>>
as it would have been typed before for resource args. Update that and, in addition, generalize the usage of _output_shapes so that one can capture
  tensor<_output_shapes x !tf.resource<tensor<_handle_shapes x _handle_dtypes>>>

PiperOrigin-RevId: 308812891
Change-Id: I4af67ca9d247ad6e63f01fb5a7206cb6dc519444
Jacques Pienaar 2020-04-28 06:37:28 -07:00 committed by TensorFlower Gardener
parent bddce29345
commit 2c1585c8a4
2 changed files with 78 additions and 46 deletions
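
For illustration only (not part of the commit): a minimal sketch of building an _Arg NodeDef that carries both attribute families, reusing the 5x4 / 3x7 / DT_FLOAT example values from the test added below. The helper name is made up, and the MLIR type in the trailing comment is an assumption based on the commit description rather than something this snippet produces.

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {

// Hypothetical helper: attach shape information to a resource _Arg node.
Status BuildShapedResourceArg(NodeDef* node_def) {
  // _output_shapes describes the _Arg output itself; _handle_shapes and
  // _handle_dtypes describe the tensor held by the resource.
  return NodeDefBuilder("arg0", "_Arg")
      .Attr("T", DataType::DT_RESOURCE)
      .Attr("index", 0)
      .Attr("_output_shapes", {TensorShape({5, 4})})
      .Attr("_handle_shapes", {TensorShape({3, 7})})
      .Attr("_handle_dtypes", {DataType::DT_FLOAT})
      .Finalize(node_def);
  // Shape inference then reports output 0 as [5, 4] with handle shape/type
  // {[3, 7], DT_FLOAT}; per the commit description, an importer could type
  // the argument roughly as tensor<5x4x!tf.resource<tensor<3x7xf32>>>.
}

}  // namespace tensorflow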


@@ -34,49 +34,53 @@ REGISTER_SYSTEM_OP("_Arg")
             "_Arg node does not have attribute \"T\"");
       }
-      if (dtype_attr->type() == DT_RESOURCE) {
-        const AttrValue* dtype_attr = context->attrs().Find("_handle_dtypes");
-        const AttrValue* shape_attr = context->attrs().Find("_handle_shapes");
-        if (dtype_attr && shape_attr) {
-          if (dtype_attr->list().type().empty()) {
-            return errors::InvalidArgument(
-                "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
-                shape_attr->DebugString());
-          }
-          if (shape_attr->list().shape().empty()) {
-            return errors::InvalidArgument(
-                "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
-                shape_attr->DebugString());
-          }
-          DataType dtype = dtype_attr->list().type(0);
-          const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
-          shape_inference::ShapeHandle shape_handle;
-          TF_RETURN_IF_ERROR(
-              context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
-          context->set_output(0, shape_handle);
-          context->set_output_handle_shapes_and_types(
-              0, std::vector<shape_inference::ShapeAndType>{
-                     {shape_handle, dtype}});
-        } else {
-          context->set_output(0, context->UnknownShape());
-        }
-      } else {
-        const AttrValue* shape_attr = context->attrs().Find("_output_shapes");
-        if (shape_attr && shape_attr->has_list()) {
-          if (shape_attr->list().shape().empty()) {
-            return errors::InvalidArgument(
-                "Invalid \"_output_shapes\" attribute value for _Arg node: ",
-                shape_attr->DebugString());
-          }
-          const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
-          shape_inference::ShapeHandle shape_handle;
-          TF_RETURN_IF_ERROR(
-              context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
-          context->set_output(0, shape_handle);
-        } else {
-          context->set_output(0, context->UnknownShape());
-        }
-      }
+      const AttrValue* shape_attr = context->attrs().Find("_output_shapes");
+      if (shape_attr && shape_attr->has_list()) {
+        if (shape_attr->list().shape().empty()) {
+          return errors::InvalidArgument(
+              "Invalid \"_output_shapes\" attribute value for _Arg node: ",
+              shape_attr->DebugString());
+        }
+        const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
+        shape_inference::ShapeHandle shape_handle;
+        TF_RETURN_IF_ERROR(
+            context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
+        context->set_output(0, shape_handle);
+      } else {
+        context->set_output(0, context->UnknownShape());
+      }
+
+      if (dtype_attr->type() != DT_RESOURCE) {
+        return Status::OK();
+      }
+
+      // If the argument is for a resource type, then also try to infer the
+      // type of the tensor store in the resource type.
+      dtype_attr = context->attrs().Find("_handle_dtypes");
+      shape_attr = context->attrs().Find("_handle_shapes");
+
+      // If either the shape or type attribute is not set then simply return
+      // with unknown output set above.
+      if (!dtype_attr || !shape_attr) {
+        return Status::OK();
+      }
+
+      if (dtype_attr->list().type().empty()) {
+        return errors::InvalidArgument(
+            "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
+            shape_attr->DebugString());
+      }
+      if (shape_attr->list().shape().empty()) {
+        return errors::InvalidArgument(
+            "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
+            shape_attr->DebugString());
+      }
+      DataType dtype = dtype_attr->list().type(0);
+      const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
+      shape_inference::ShapeHandle shape_handle;
+      TF_RETURN_IF_ERROR(
+          context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
+      context->set_output_handle_shapes_and_types(
+          0, std::vector<shape_inference::ShapeAndType>{{shape_handle, dtype}});
+
       return Status::OK();
     })
     .Doc(R"doc(
@@ -86,17 +90,15 @@ output: The argument.
 index: This argument is the index-th argument of the function.
 
 Attributes for shape inference:
-1. _output_shapes: this attribute can be set on an _Arg node producing
-   non-resource output(s). If set, its value should contain a list of
-   TensorShapeProto describing the shape(s) of the tensor(s) this _Arg node will
-   produce. If set, _Arg node's shape inference function will use it as the
-   node's output shapes.
+1. _output_shapes: this attribute should contain a list of TensorShapeProto
+   describing the shape(s) of the tensor(s) this _Arg node will produce. If set,
+   _Arg node's shape inference function will use it as the node's output shapes.
 2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
    node producing resource output(s). If set, value of _handle_dtypes should
    contain the dtype(s) of the resource(s) and value of _handle_shapes should
    contain the shape(s) of the resource(s). If both attributes are set, _Arg
    node's shape inference function will use their values as the node's output
-   type(s) and shape(s).
+   handle's type(s) and shape(s).
 )doc");
 
 REGISTER_SYSTEM_OP("_DeviceArg")
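
Complementing point 1 of the doc above, here is a sketch of the non-resource path, written against the same shape-inference test utilities as the test added below; the test name and the [2, 3] shape are made up for illustration, and the snippet is not part of the commit.

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

// Hypothetical test: a non-resource _Arg with only _output_shapes set.
TEST(FunctionalOpsTest, Arg_NonResource_ShapeFn) {
  ShapeInferenceTestOp op("_Arg");
  TF_ASSERT_OK(NodeDefBuilder("test", "_Arg")
                   .Attr("T", DataType::DT_FLOAT)
                   .Attr("index", 0)
                   .Attr("_output_shapes", {TensorShape({2, 3})})
                   .Finalize(&op.node_def));
  // _Arg has no inputs, so the input spec is empty. Output 0 takes its shape
  // straight from _output_shapes; since T != DT_RESOURCE, no handle
  // shapes/types are recorded for the output.
  INFER_OK(op, "", "[2,3]");
}

}  // namespace tensorflow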


@@ -21,6 +21,36 @@ limitations under the License.
 
 namespace tensorflow {
 
+TEST(FunctionalOpsTest, Arg_ShapeFn) {
+  ShapeInferenceTestOp op("_Arg");
+  std::vector<DataType> out_type_list;
+  out_type_list.emplace_back(DT_RESOURCE);
+  TF_ASSERT_OK(NodeDefBuilder("test", "_Arg")
+                   .Attr("T", DataType::DT_RESOURCE)
+                   .Attr("index", 0)
+                   .Attr("_output_shapes", {TensorShape({5, 4})})
+                   .Attr("_handle_shapes", {TensorShape({3, 7})})
+                   .Attr("_handle_dtypes", {DataType::DT_FLOAT})
+                   .Finalize(&op.node_def));
+  const OpRegistrationData* op_reg_data;
+  TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
+  shape_inference::InferenceContext c(
+      op.graph_def_version, op.node_def, op_reg_data->op_def,
+      std::vector<shape_inference::ShapeHandle>{}, op.input_tensors, {}, {});
+  TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
+
+  auto output = c.output(0);
+  ASSERT_EQ(c.Value(c.Rank(output)), 2);
+  EXPECT_EQ(c.Value(c.Dim(output, 0)), 5);
+  EXPECT_EQ(c.Value(c.Dim(output, 1)), 4);
+
+  auto outputs = c.output_handle_shapes_and_types(0);
+  ASSERT_EQ(outputs->size(), 1);
+  EXPECT_EQ(outputs->front().dtype, DataType::DT_FLOAT);
+  EXPECT_EQ(c.Value(c.Dim(outputs->front().shape, 0)), 3);
+  EXPECT_EQ(c.Value(c.Dim(outputs->front().shape, 1)), 7);
+}
+
 TEST(FunctionalOpsTest, SymbolicGradient_ShapeFn) {
   ShapeInferenceTestOp op("SymbolicGradient");
   int num_inputs = 4;