_Arg node requires _handle_dtypes and _handle_shapes attr when T is DT_RESOURCE.

Otherwise, function node with DT_RESOURCE type input fails to infer output shapes.

PiperOrigin-RevId: 308927584
Change-Id: I653b293aaa36ffd4a6bed68e4692398eebd9d578
This commit is contained in:
Doe Hyun Yoon 2020-04-28 17:09:08 -07:00 committed by TensorFlower Gardener
parent 62856c9366
commit 7028dbefc0
3 changed files with 560 additions and 264 deletions

View File

@ -804,14 +804,47 @@ class SymbolicShapeRefiner {
int output_port_num = input_tensor.index();
AttrValue attr_output_shape;
TensorShapeProto proto;
const auto& handle = input_ic->output(output_port_num);
const auto handle = input_ic->output(output_port_num);
input_ic->ShapeHandleToProto(handle, &proto);
// There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
NormalizeShapeForOutput(&proto);
// _Arg op's output shape uses _output_shapes attr.
AttrValue output_attr;
output_attr.mutable_list()->add_shape()->Swap(&proto);
(*fun_node->mutable_attr())["_output_shapes"] = output_attr;
// If dtype is DT_RESOURCE, ops that read _Arg op use _handle_dtypes and
// _handle_shapes attr for its shapes and dtypes.
if (fun_input.data_type == DT_RESOURCE) {
auto* shapes_and_types =
input_ic->output_handle_shapes_and_types(output_port_num);
if (shapes_and_types != nullptr && !shapes_and_types->empty()) {
AttrValue dtype_attr;
AttrValue shape_attr;
for (const auto& shape_and_type : *shapes_and_types) {
const auto& dtype = shape_and_type.dtype;
const auto& shape_handle = shape_and_type.shape;
dtype_attr.mutable_list()->add_type(dtype);
input_ic->ShapeHandleToProto(
shape_handle, shape_attr.mutable_list()->add_shape());
}
(*fun_node->mutable_attr())["_handle_dtypes"] = dtype_attr;
(*fun_node->mutable_attr())["_handle_shapes"] = shape_attr;
} else {
// Note that we do not return error here, even if the input node does
// not have shapes_and_types. Within the function, we cannot infer the
// output shape of the DT_RESOURCE input; hence, potentially unknown
// shapes/dims in the function output shapes.
VLOG(2)
<< "A function node (" << function_node->name()
<< ") has input with DT_RESOURCE, but the input node does not "
<< "have shapes_and_types information: \n"
<< "function_node: " << function_node->ShortDebugString() << "\n"
<< "function input: " << i
<< ", input node's output: " << output_port_num << "\n"
<< "input node: " << input_node->ShortDebugString();
}
}
}
// Replace input nodes with Consts, if values are known. Note that
@ -820,7 +853,7 @@ class SymbolicShapeRefiner {
auto* ic = ctx->inference_context.get();
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
const string& input = function_node->input(i);
const string& node_name = NodeName(input);
const string node_name = NodeName(input);
const NodeDef* input_node = graph_.GetNode(node_name);
if (IsConstant(*input_node)) {
TF_CHECK_OK(
@ -993,9 +1026,11 @@ class SymbolicShapeRefiner {
// Convert all kUnknownDimFromConst to -1 for shape inference.
ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim(
ic, ctx->input_tensors_as_shapes_to_propagate));
// Notice: UpdateFunction only uses input_tensors_as_shapes, so for function
// nodes, we dont' perform the conversion from TensorProtos to Tensors for
// constant inputs here.
// Note: UpdateFunction uses input_tensors_as_shapes and
// input_tensor_protos (not the Tensor object) for input values.
// so for function nodes, we don't need to convert TensorProtos
// to Tensors here. If the current op is not a function op, we convert
// TensorProtos to Tensors before calling InferShapes.
// Properly handle function nodes.
if (ctx->op_data && ctx->op_data->is_function_op) {

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,130 @@
node {
name: "x"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
dim {
size: 3
}
}
float_val: 2.0
}
}
}
}
node {
name: "y"
op: "_Arg"
attr {
key: "T"
value {
type: DT_RESOURCE
}
}
attr {
key: "_handle_dtypes"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "_handle_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 3
}
}
}
}
}
attr {
key: "index"
value {
i: 0
}
}
}
node {
name: "FunctionWithDtResourceInput"
op: "FunctionWithDtResourceInput"
input: "x"
input: "y"
}
library {
function {
signature {
name: "FunctionWithDtResourceInput"
input_arg {
name: "x"
type: DT_FLOAT
}
input_arg {
name: "y"
type: DT_RESOURCE
}
output_arg {
name: "z1"
type: DT_FLOAT
}
output_arg {
name: "z2"
type: DT_FLOAT
}
}
node_def {
name: "y1"
op: "ReadVariableOp"
input: "y"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node_def {
name: "Add"
op: "Add"
input: "x"
input: "y1:value:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
ret {
key: "z1"
value: "Add:z:0"
}
ret {
key: "z2"
value: "x"
}
}
}
versions {
producer: 26
min_consumer: 12
}