Enable setting shape of _Arg tensor for shape of resource args
Post cl/240699970 the type of the argument would be inferred as tensor<Y x !tf.resource<tensor<Y x dtype>>> instead of tensor<* x !tf.resource<tensor<Y x dtype>>> as it would have been typed before for resource args. Update that an in addition generalize the usage of _output_shapes so that one can capture tensor<_output_shapes x !tf.resource<tensor<_handle_shapes x _handle_dtypes>>> PiperOrigin-RevId: 308812891 Change-Id: I4af67ca9d247ad6e63f01fb5a7206cb6dc519444
This commit is contained in:
parent
bddce29345
commit
2c1585c8a4
@ -34,49 +34,53 @@ REGISTER_SYSTEM_OP("_Arg")
|
|||||||
"_Arg node does not have attribute \"T\"");
|
"_Arg node does not have attribute \"T\"");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dtype_attr->type() == DT_RESOURCE) {
|
const AttrValue* shape_attr = context->attrs().Find("_output_shapes");
|
||||||
const AttrValue* dtype_attr = context->attrs().Find("_handle_dtypes");
|
if (shape_attr && shape_attr->has_list()) {
|
||||||
const AttrValue* shape_attr = context->attrs().Find("_handle_shapes");
|
if (shape_attr->list().shape().empty()) {
|
||||||
if (dtype_attr && shape_attr) {
|
return errors::InvalidArgument(
|
||||||
if (dtype_attr->list().type().empty()) {
|
"Invalid \"_output_shapes\" attribute value for _Arg node: ",
|
||||||
return errors::InvalidArgument(
|
shape_attr->DebugString());
|
||||||
"Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
|
|
||||||
shape_attr->DebugString());
|
|
||||||
}
|
|
||||||
if (shape_attr->list().shape().empty()) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Invalid \"_handle_shapes\" attribute value for _Arg node: ",
|
|
||||||
shape_attr->DebugString());
|
|
||||||
}
|
|
||||||
DataType dtype = dtype_attr->list().type(0);
|
|
||||||
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
|
||||||
shape_inference::ShapeHandle shape_handle;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
|
||||||
context->set_output(0, shape_handle);
|
|
||||||
context->set_output_handle_shapes_and_types(
|
|
||||||
0, std::vector<shape_inference::ShapeAndType>{
|
|
||||||
{shape_handle, dtype}});
|
|
||||||
} else {
|
|
||||||
context->set_output(0, context->UnknownShape());
|
|
||||||
}
|
}
|
||||||
|
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||||
|
shape_inference::ShapeHandle shape_handle;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||||
|
context->set_output(0, shape_handle);
|
||||||
} else {
|
} else {
|
||||||
const AttrValue* shape_attr = context->attrs().Find("_output_shapes");
|
context->set_output(0, context->UnknownShape());
|
||||||
if (shape_attr && shape_attr->has_list()) {
|
|
||||||
if (shape_attr->list().shape().empty()) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Invalid \"_output_shapes\" attribute value for _Arg node: ",
|
|
||||||
shape_attr->DebugString());
|
|
||||||
}
|
|
||||||
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
|
||||||
shape_inference::ShapeHandle shape_handle;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
|
||||||
context->set_output(0, shape_handle);
|
|
||||||
} else {
|
|
||||||
context->set_output(0, context->UnknownShape());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (dtype_attr->type() != DT_RESOURCE) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the argument is for a resource type, then also try to infer the
|
||||||
|
// type of the tensor store in the resource type.
|
||||||
|
dtype_attr = context->attrs().Find("_handle_dtypes");
|
||||||
|
shape_attr = context->attrs().Find("_handle_shapes");
|
||||||
|
// If either the shape or type attribute is not set then simply return
|
||||||
|
// with unknown output set above.
|
||||||
|
if (!dtype_attr || !shape_attr) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dtype_attr->list().type().empty()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
|
||||||
|
shape_attr->DebugString());
|
||||||
|
}
|
||||||
|
if (shape_attr->list().shape().empty()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Invalid \"_handle_shapes\" attribute value for _Arg node: ",
|
||||||
|
shape_attr->DebugString());
|
||||||
|
}
|
||||||
|
DataType dtype = dtype_attr->list().type(0);
|
||||||
|
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||||
|
shape_inference::ShapeHandle shape_handle;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||||
|
context->set_output_handle_shapes_and_types(
|
||||||
|
0, std::vector<shape_inference::ShapeAndType>{{shape_handle, dtype}});
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
@ -86,17 +90,15 @@ output: The argument.
|
|||||||
index: This argument is the index-th argument of the function.
|
index: This argument is the index-th argument of the function.
|
||||||
|
|
||||||
Attributes for shape inference:
|
Attributes for shape inference:
|
||||||
1. _output_shapes: this attribute can be set on an _Arg node producing
|
1. _output_shapes: this attribute should contain a list of TensorShapeProto
|
||||||
non-resource output(s). If set, its value should contain a list of
|
describing the shape(s) of the tensor(s) this _Arg node will produce. If set,
|
||||||
TensorShapeProto describing the shape(s) of the tensor(s) this _Arg node will
|
_Arg node's shape inference function will use it as the node's output shapes.
|
||||||
produce. If set, _Arg node's shape inference function will use it as the
|
|
||||||
node's output shapes.
|
|
||||||
2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
|
2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
|
||||||
node producing resource output(s). If set, value of _handle_dtypes should
|
node producing resource output(s). If set, value of _handle_dtypes should
|
||||||
contain the dtype(s) of the resource(s) and value of _handle_shapes should
|
contain the dtype(s) of the resource(s) and value of _handle_shapes should
|
||||||
contain the shape(s) of the resource(s). If both attributes are set, _Arg
|
contain the shape(s) of the resource(s). If both attributes are set, _Arg
|
||||||
node's shape inference function will use their values as the node's output
|
node's shape inference function will use their values as the node's output
|
||||||
type(s) and shape(s).
|
handle's type(s) and shape(s).
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
REGISTER_SYSTEM_OP("_DeviceArg")
|
REGISTER_SYSTEM_OP("_DeviceArg")
|
||||||
|
@ -21,6 +21,36 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
TEST(FunctionalOpsTest, Arg_ShapeFn) {
|
||||||
|
ShapeInferenceTestOp op("_Arg");
|
||||||
|
std::vector<DataType> out_type_list;
|
||||||
|
out_type_list.emplace_back(DT_RESOURCE);
|
||||||
|
TF_ASSERT_OK(NodeDefBuilder("test", "_Arg")
|
||||||
|
.Attr("T", DataType::DT_RESOURCE)
|
||||||
|
.Attr("index", 0)
|
||||||
|
.Attr("_output_shapes", {TensorShape({5, 4})})
|
||||||
|
.Attr("_handle_shapes", {TensorShape({3, 7})})
|
||||||
|
.Attr("_handle_dtypes", {DataType::DT_FLOAT})
|
||||||
|
.Finalize(&op.node_def));
|
||||||
|
|
||||||
|
const OpRegistrationData* op_reg_data;
|
||||||
|
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
|
||||||
|
shape_inference::InferenceContext c(
|
||||||
|
op.graph_def_version, op.node_def, op_reg_data->op_def,
|
||||||
|
std::vector<shape_inference::ShapeHandle>{}, op.input_tensors, {}, {});
|
||||||
|
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
|
||||||
|
auto output = c.output(0);
|
||||||
|
ASSERT_EQ(c.Value(c.Rank(output)), 2);
|
||||||
|
EXPECT_EQ(c.Value(c.Dim(output, 0)), 5);
|
||||||
|
EXPECT_EQ(c.Value(c.Dim(output, 1)), 4);
|
||||||
|
|
||||||
|
auto outputs = c.output_handle_shapes_and_types(0);
|
||||||
|
ASSERT_EQ(outputs->size(), 1);
|
||||||
|
EXPECT_EQ(outputs->front().dtype, DataType::DT_FLOAT);
|
||||||
|
EXPECT_EQ(c.Value(c.Dim(outputs->front().shape, 0)), 3);
|
||||||
|
EXPECT_EQ(c.Value(c.Dim(outputs->front().shape, 1)), 7);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(FunctionalOpsTest, SymbolicGradient_ShapeFn) {
|
TEST(FunctionalOpsTest, SymbolicGradient_ShapeFn) {
|
||||||
ShapeInferenceTestOp op("SymbolicGradient");
|
ShapeInferenceTestOp op("SymbolicGradient");
|
||||||
int num_inputs = 4;
|
int num_inputs = 4;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user