Enable setting shape of _Arg tensor for shape of resource args
Post cl/240699970 the type of the argument would be inferred as tensor<Y x !tf.resource<tensor<Y x dtype>>> instead of tensor<* x !tf.resource<tensor<Y x dtype>>> as it would have been typed before for resource args. Update that an in addition generalize the usage of _output_shapes so that one can capture tensor<_output_shapes x !tf.resource<tensor<_handle_shapes x _handle_dtypes>>> PiperOrigin-RevId: 308812891 Change-Id: I4af67ca9d247ad6e63f01fb5a7206cb6dc519444
This commit is contained in:
parent
bddce29345
commit
2c1585c8a4
@ -34,49 +34,53 @@ REGISTER_SYSTEM_OP("_Arg")
|
||||
"_Arg node does not have attribute \"T\"");
|
||||
}
|
||||
|
||||
if (dtype_attr->type() == DT_RESOURCE) {
|
||||
const AttrValue* dtype_attr = context->attrs().Find("_handle_dtypes");
|
||||
const AttrValue* shape_attr = context->attrs().Find("_handle_shapes");
|
||||
if (dtype_attr && shape_attr) {
|
||||
if (dtype_attr->list().type().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
|
||||
shape_attr->DebugString());
|
||||
}
|
||||
if (shape_attr->list().shape().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid \"_handle_shapes\" attribute value for _Arg node: ",
|
||||
shape_attr->DebugString());
|
||||
}
|
||||
DataType dtype = dtype_attr->list().type(0);
|
||||
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||
context->set_output(0, shape_handle);
|
||||
context->set_output_handle_shapes_and_types(
|
||||
0, std::vector<shape_inference::ShapeAndType>{
|
||||
{shape_handle, dtype}});
|
||||
} else {
|
||||
context->set_output(0, context->UnknownShape());
|
||||
const AttrValue* shape_attr = context->attrs().Find("_output_shapes");
|
||||
if (shape_attr && shape_attr->has_list()) {
|
||||
if (shape_attr->list().shape().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid \"_output_shapes\" attribute value for _Arg node: ",
|
||||
shape_attr->DebugString());
|
||||
}
|
||||
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||
context->set_output(0, shape_handle);
|
||||
} else {
|
||||
const AttrValue* shape_attr = context->attrs().Find("_output_shapes");
|
||||
if (shape_attr && shape_attr->has_list()) {
|
||||
if (shape_attr->list().shape().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid \"_output_shapes\" attribute value for _Arg node: ",
|
||||
shape_attr->DebugString());
|
||||
}
|
||||
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||
context->set_output(0, shape_handle);
|
||||
} else {
|
||||
context->set_output(0, context->UnknownShape());
|
||||
}
|
||||
context->set_output(0, context->UnknownShape());
|
||||
}
|
||||
|
||||
if (dtype_attr->type() != DT_RESOURCE) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If the argument is for a resource type, then also try to infer the
|
||||
// type of the tensor store in the resource type.
|
||||
dtype_attr = context->attrs().Find("_handle_dtypes");
|
||||
shape_attr = context->attrs().Find("_handle_shapes");
|
||||
// If either the shape or type attribute is not set then simply return
|
||||
// with unknown output set above.
|
||||
if (!dtype_attr || !shape_attr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (dtype_attr->list().type().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
|
||||
shape_attr->DebugString());
|
||||
}
|
||||
if (shape_attr->list().shape().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid \"_handle_shapes\" attribute value for _Arg node: ",
|
||||
shape_attr->DebugString());
|
||||
}
|
||||
DataType dtype = dtype_attr->list().type(0);
|
||||
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||
context->set_output_handle_shapes_and_types(
|
||||
0, std::vector<shape_inference::ShapeAndType>{{shape_handle, dtype}});
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
@ -86,17 +90,15 @@ output: The argument.
|
||||
index: This argument is the index-th argument of the function.
|
||||
|
||||
Attributes for shape inference:
|
||||
1. _output_shapes: this attribute can be set on an _Arg node producing
|
||||
non-resource output(s). If set, its value should contain a list of
|
||||
TensorShapeProto describing the shape(s) of the tensor(s) this _Arg node will
|
||||
produce. If set, _Arg node's shape inference function will use it as the
|
||||
node's output shapes.
|
||||
1. _output_shapes: this attribute should contain a list of TensorShapeProto
|
||||
describing the shape(s) of the tensor(s) this _Arg node will produce. If set,
|
||||
_Arg node's shape inference function will use it as the node's output shapes.
|
||||
2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
|
||||
node producing resource output(s). If set, value of _handle_dtypes should
|
||||
contain the dtype(s) of the resource(s) and value of _handle_shapes should
|
||||
contain the shape(s) of the resource(s). If both attributes are set, _Arg
|
||||
node's shape inference function will use their values as the node's output
|
||||
type(s) and shape(s).
|
||||
handle's type(s) and shape(s).
|
||||
)doc");
|
||||
|
||||
REGISTER_SYSTEM_OP("_DeviceArg")
|
||||
|
@ -21,6 +21,36 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TEST(FunctionalOpsTest, Arg_ShapeFn) {
|
||||
ShapeInferenceTestOp op("_Arg");
|
||||
std::vector<DataType> out_type_list;
|
||||
out_type_list.emplace_back(DT_RESOURCE);
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "_Arg")
|
||||
.Attr("T", DataType::DT_RESOURCE)
|
||||
.Attr("index", 0)
|
||||
.Attr("_output_shapes", {TensorShape({5, 4})})
|
||||
.Attr("_handle_shapes", {TensorShape({3, 7})})
|
||||
.Attr("_handle_dtypes", {DataType::DT_FLOAT})
|
||||
.Finalize(&op.node_def));
|
||||
|
||||
const OpRegistrationData* op_reg_data;
|
||||
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
|
||||
shape_inference::InferenceContext c(
|
||||
op.graph_def_version, op.node_def, op_reg_data->op_def,
|
||||
std::vector<shape_inference::ShapeHandle>{}, op.input_tensors, {}, {});
|
||||
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
|
||||
auto output = c.output(0);
|
||||
ASSERT_EQ(c.Value(c.Rank(output)), 2);
|
||||
EXPECT_EQ(c.Value(c.Dim(output, 0)), 5);
|
||||
EXPECT_EQ(c.Value(c.Dim(output, 1)), 4);
|
||||
|
||||
auto outputs = c.output_handle_shapes_and_types(0);
|
||||
ASSERT_EQ(outputs->size(), 1);
|
||||
EXPECT_EQ(outputs->front().dtype, DataType::DT_FLOAT);
|
||||
EXPECT_EQ(c.Value(c.Dim(outputs->front().shape, 0)), 3);
|
||||
EXPECT_EQ(c.Value(c.Dim(outputs->front().shape, 1)), 7);
|
||||
}
|
||||
|
||||
TEST(FunctionalOpsTest, SymbolicGradient_ShapeFn) {
|
||||
ShapeInferenceTestOp op("SymbolicGradient");
|
||||
int num_inputs = 4;
|
||||
|
Loading…
x
Reference in New Issue
Block a user