_Arg node requires _handle_dtypes and _handle_shapes attr when T is DT_RESOURCE.
Otherwise, function node with DT_RESOURCE type input fails to infer output shapes. PiperOrigin-RevId: 308927584 Change-Id: I653b293aaa36ffd4a6bed68e4692398eebd9d578
This commit is contained in:
parent
62856c9366
commit
7028dbefc0
@ -804,14 +804,47 @@ class SymbolicShapeRefiner {
|
||||
int output_port_num = input_tensor.index();
|
||||
AttrValue attr_output_shape;
|
||||
TensorShapeProto proto;
|
||||
const auto& handle = input_ic->output(output_port_num);
|
||||
const auto handle = input_ic->output(output_port_num);
|
||||
input_ic->ShapeHandleToProto(handle, &proto);
|
||||
// There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
|
||||
NormalizeShapeForOutput(&proto);
|
||||
|
||||
// _Arg op's output shape uses _output_shapes attr.
|
||||
AttrValue output_attr;
|
||||
output_attr.mutable_list()->add_shape()->Swap(&proto);
|
||||
(*fun_node->mutable_attr())["_output_shapes"] = output_attr;
|
||||
|
||||
// If dtype is DT_RESOURCE, ops that read _Arg op use _handle_dtypes and
|
||||
// _handle_shapes attr for its shapes and dtypes.
|
||||
if (fun_input.data_type == DT_RESOURCE) {
|
||||
auto* shapes_and_types =
|
||||
input_ic->output_handle_shapes_and_types(output_port_num);
|
||||
if (shapes_and_types != nullptr && !shapes_and_types->empty()) {
|
||||
AttrValue dtype_attr;
|
||||
AttrValue shape_attr;
|
||||
for (const auto& shape_and_type : *shapes_and_types) {
|
||||
const auto& dtype = shape_and_type.dtype;
|
||||
const auto& shape_handle = shape_and_type.shape;
|
||||
dtype_attr.mutable_list()->add_type(dtype);
|
||||
input_ic->ShapeHandleToProto(
|
||||
shape_handle, shape_attr.mutable_list()->add_shape());
|
||||
}
|
||||
(*fun_node->mutable_attr())["_handle_dtypes"] = dtype_attr;
|
||||
(*fun_node->mutable_attr())["_handle_shapes"] = shape_attr;
|
||||
} else {
|
||||
// Note that we do not return error here, even if the input node does
|
||||
// not have shapes_and_types. Within the function, we cannot infer the
|
||||
// output shape of the DT_RESOURCE input; hence, potentially unknown
|
||||
// shapes/dims in the function output shapes.
|
||||
VLOG(2)
|
||||
<< "A function node (" << function_node->name()
|
||||
<< ") has input with DT_RESOURCE, but the input node does not "
|
||||
<< "have shapes_and_types information: \n"
|
||||
<< "function_node: " << function_node->ShortDebugString() << "\n"
|
||||
<< "function input: " << i
|
||||
<< ", input node's output: " << output_port_num << "\n"
|
||||
<< "input node: " << input_node->ShortDebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Replace input nodes with Consts, if values are known. Note that
|
||||
@ -820,7 +853,7 @@ class SymbolicShapeRefiner {
|
||||
auto* ic = ctx->inference_context.get();
|
||||
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
|
||||
const string& input = function_node->input(i);
|
||||
const string& node_name = NodeName(input);
|
||||
const string node_name = NodeName(input);
|
||||
const NodeDef* input_node = graph_.GetNode(node_name);
|
||||
if (IsConstant(*input_node)) {
|
||||
TF_CHECK_OK(
|
||||
@ -993,9 +1026,11 @@ class SymbolicShapeRefiner {
|
||||
// Convert all kUnknownDimFromConst to -1 for shape inference.
|
||||
ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim(
|
||||
ic, ctx->input_tensors_as_shapes_to_propagate));
|
||||
// Notice: UpdateFunction only uses input_tensors_as_shapes, so for function
|
||||
// nodes, we dont' perform the conversion from TensorProtos to Tensors for
|
||||
// constant inputs here.
|
||||
// Note: UpdateFunction uses input_tensors_as_shapes and
|
||||
// input_tensor_protos (not the Tensor object) for input values.
|
||||
// so for function nodes, we don't need to convert TensorProtos
|
||||
// to Tensors here. If the current op is not a function op, we convert
|
||||
// TensorProtos to Tensors before calling InferShapes.
|
||||
|
||||
// Properly handle function nodes.
|
||||
if (ctx->op_data && ctx->op_data->is_function_op) {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,130 @@
|
||||
node {
|
||||
name: "x"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
float_val: 2.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "y"
|
||||
op: "_Arg"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_handle_dtypes"
|
||||
value {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_handle_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "FunctionWithDtResourceInput"
|
||||
op: "FunctionWithDtResourceInput"
|
||||
input: "x"
|
||||
input: "y"
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "FunctionWithDtResourceInput"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
output_arg {
|
||||
name: "z1"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "z2"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "y1"
|
||||
op: "ReadVariableOp"
|
||||
input: "y"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "x"
|
||||
input: "y1:value:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "z1"
|
||||
value: "Add:z:0"
|
||||
}
|
||||
ret {
|
||||
key: "z2"
|
||||
value: "x"
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 26
|
||||
min_consumer: 12
|
||||
}
|
Loading…
Reference in New Issue
Block a user