Enable setting shape of _Arg tensor for shape of resource args

Post cl/240699970 the type of the argument would be inferred as
  tensor<Y x !tf.resource<tensor<Y x dtype>>>
instead of
  tensor<* x !tf.resource<tensor<Y x dtype>>>
as it would have been typed before for resource args. Update that an in addition generalize the usage of _output_shapes so that one can capture
  tensor<_output_shapes x !tf.resource<tensor<_handle_shapes x _handle_dtypes>>>

PiperOrigin-RevId: 308812891
Change-Id: I4af67ca9d247ad6e63f01fb5a7206cb6dc519444
This commit is contained in:
Jacques Pienaar 2020-04-28 06:37:28 -07:00 committed by TensorFlower Gardener
parent bddce29345
commit 2c1585c8a4
2 changed files with 78 additions and 46 deletions

View File

@ -34,49 +34,53 @@ REGISTER_SYSTEM_OP("_Arg")
"_Arg node does not have attribute \"T\"");
}
if (dtype_attr->type() == DT_RESOURCE) {
const AttrValue* dtype_attr = context->attrs().Find("_handle_dtypes");
const AttrValue* shape_attr = context->attrs().Find("_handle_shapes");
if (dtype_attr && shape_attr) {
if (dtype_attr->list().type().empty()) {
return errors::InvalidArgument(
"Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
shape_attr->DebugString());
}
if (shape_attr->list().shape().empty()) {
return errors::InvalidArgument(
"Invalid \"_handle_shapes\" attribute value for _Arg node: ",
shape_attr->DebugString());
}
DataType dtype = dtype_attr->list().type(0);
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
shape_inference::ShapeHandle shape_handle;
TF_RETURN_IF_ERROR(
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
context->set_output(0, shape_handle);
context->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{
{shape_handle, dtype}});
} else {
context->set_output(0, context->UnknownShape());
const AttrValue* shape_attr = context->attrs().Find("_output_shapes");
if (shape_attr && shape_attr->has_list()) {
if (shape_attr->list().shape().empty()) {
return errors::InvalidArgument(
"Invalid \"_output_shapes\" attribute value for _Arg node: ",
shape_attr->DebugString());
}
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
shape_inference::ShapeHandle shape_handle;
TF_RETURN_IF_ERROR(
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
context->set_output(0, shape_handle);
} else {
const AttrValue* shape_attr = context->attrs().Find("_output_shapes");
if (shape_attr && shape_attr->has_list()) {
if (shape_attr->list().shape().empty()) {
return errors::InvalidArgument(
"Invalid \"_output_shapes\" attribute value for _Arg node: ",
shape_attr->DebugString());
}
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
shape_inference::ShapeHandle shape_handle;
TF_RETURN_IF_ERROR(
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
context->set_output(0, shape_handle);
} else {
context->set_output(0, context->UnknownShape());
}
context->set_output(0, context->UnknownShape());
}
if (dtype_attr->type() != DT_RESOURCE) {
return Status::OK();
}
// If the argument is for a resource type, then also try to infer the
// type of the tensor store in the resource type.
dtype_attr = context->attrs().Find("_handle_dtypes");
shape_attr = context->attrs().Find("_handle_shapes");
// If either the shape or type attribute is not set then simply return
// with unknown output set above.
if (!dtype_attr || !shape_attr) {
return Status::OK();
}
if (dtype_attr->list().type().empty()) {
return errors::InvalidArgument(
"Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
shape_attr->DebugString());
}
if (shape_attr->list().shape().empty()) {
return errors::InvalidArgument(
"Invalid \"_handle_shapes\" attribute value for _Arg node: ",
shape_attr->DebugString());
}
DataType dtype = dtype_attr->list().type(0);
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
shape_inference::ShapeHandle shape_handle;
TF_RETURN_IF_ERROR(
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
context->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{shape_handle, dtype}});
return Status::OK();
})
.Doc(R"doc(
@ -86,17 +90,15 @@ output: The argument.
index: This argument is the index-th argument of the function.
Attributes for shape inference:
1. _output_shapes: this attribute can be set on an _Arg node producing
non-resource output(s). If set, its value should contain a list of
TensorShapeProto describing the shape(s) of the tensor(s) this _Arg node will
produce. If set, _Arg node's shape inference function will use it as the
node's output shapes.
1. _output_shapes: this attribute should contain a list of TensorShapeProto
describing the shape(s) of the tensor(s) this _Arg node will produce. If set,
_Arg node's shape inference function will use it as the node's output shapes.
2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
node producing resource output(s). If set, value of _handle_dtypes should
contain the dtype(s) of the resource(s) and value of _handle_shapes should
contain the shape(s) of the resource(s). If both attributes are set, _Arg
node's shape inference function will use their values as the node's output
type(s) and shape(s).
handle's type(s) and shape(s).
)doc");
REGISTER_SYSTEM_OP("_DeviceArg")

View File

@ -21,6 +21,36 @@ limitations under the License.
namespace tensorflow {
TEST(FunctionalOpsTest, Arg_ShapeFn) {
ShapeInferenceTestOp op("_Arg");
std::vector<DataType> out_type_list;
out_type_list.emplace_back(DT_RESOURCE);
TF_ASSERT_OK(NodeDefBuilder("test", "_Arg")
.Attr("T", DataType::DT_RESOURCE)
.Attr("index", 0)
.Attr("_output_shapes", {TensorShape({5, 4})})
.Attr("_handle_shapes", {TensorShape({3, 7})})
.Attr("_handle_dtypes", {DataType::DT_FLOAT})
.Finalize(&op.node_def));
const OpRegistrationData* op_reg_data;
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
shape_inference::InferenceContext c(
op.graph_def_version, op.node_def, op_reg_data->op_def,
std::vector<shape_inference::ShapeHandle>{}, op.input_tensors, {}, {});
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
auto output = c.output(0);
ASSERT_EQ(c.Value(c.Rank(output)), 2);
EXPECT_EQ(c.Value(c.Dim(output, 0)), 5);
EXPECT_EQ(c.Value(c.Dim(output, 1)), 4);
auto outputs = c.output_handle_shapes_and_types(0);
ASSERT_EQ(outputs->size(), 1);
EXPECT_EQ(outputs->front().dtype, DataType::DT_FLOAT);
EXPECT_EQ(c.Value(c.Dim(outputs->front().shape, 0)), 3);
EXPECT_EQ(c.Value(c.Dim(outputs->front().shape, 1)), 7);
}
TEST(FunctionalOpsTest, SymbolicGradient_ShapeFn) {
ShapeInferenceTestOp op("SymbolicGradient");
int num_inputs = 4;