Support constant folding across Arg nodes during shape inference on nested function calls.
PiperOrigin-RevId: 337992497 Change-Id: If480c605fa34def3586a62105c5517c02c2f73e6
This commit is contained in:
parent
969055f71a
commit
f41a9a9335
@ -162,7 +162,8 @@
|
||||
stateful ops.
|
||||
* Added `tf.config.experimental.get_memory_usage` to return total memory
|
||||
usage of the device.
|
||||
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
|
||||
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
|
||||
* Improved shape inference of nested function calls by supporting constant folding across Arg nodes, which makes more static values available to shape inference functions.
|
||||
* `tf.data`:
|
||||
* tf.data service:
|
||||
* Added new `tf.data.experimental.service.register_dataset` and
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
@ -123,6 +124,17 @@ bool HasCpuKernel(const Node& node) {
|
||||
.ok();
|
||||
}
|
||||
|
||||
// Reads the "index" attribute of the Arg node 'node' into '*index' and checks
// that it lies in [0, num_function_inputs).
//
// Returns the error produced by GetNodeAttr if the attribute cannot be read,
// or an Internal error if the index is out of range.
Status GetArgNodeIndex(const Node* node, int num_function_inputs, int* index) {
  DCHECK(node->IsArg());
  TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", index));
  if (*index < 0 || num_function_inputs <= *index) {
    // Report the offending value itself, not the address of the out-parameter.
    return errors::Internal(
        "Function instantiation included invalid input index: ", *index,
        " not in [0, ", num_function_inputs, ").");
  }
  return Status::OK();
}
|
||||
|
||||
// Extracts the subgraph ending at 'target_node' that is statically computable
|
||||
// and inserts into 'out_graph'. If statically computable, 'is_constant_graph'
|
||||
// will be set to true.
|
||||
@ -130,7 +142,8 @@ Status ExtractConstantSubgraph(
|
||||
const Node& target_node, const ShapeRefiner& refiner,
|
||||
const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
|
||||
bool* is_constant_graph,
|
||||
std::vector<std::pair<string, Tensor>>* const_inputs) {
|
||||
std::vector<std::pair<string, Tensor>>* const_inputs,
|
||||
InferenceContext* outer_context) {
|
||||
*is_constant_graph = false;
|
||||
std::unordered_set<string> const_inputs_added;
|
||||
|
||||
@ -187,8 +200,9 @@ Status ExtractConstantSubgraph(
|
||||
edges_to_visit.pop_front();
|
||||
Node* current_node = current_edge->src();
|
||||
|
||||
// If the node is stateful, assume the graph is not constant.
|
||||
if (current_node->op_def().is_stateful()) {
|
||||
// If the node is stateful, assume the graph is not constant unless it is
|
||||
// an Arg node which is handled later on.
|
||||
if (!current_node->IsArg() && current_node->op_def().is_stateful()) {
|
||||
*is_constant_graph = false;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -223,9 +237,32 @@ Status ExtractConstantSubgraph(
|
||||
}
|
||||
|
||||
// If there is nothing more to recurse down, see if
|
||||
// the generator node is a constant.
|
||||
// the generator node is a constant or an Arg node whose value is available
|
||||
// in the `outer_context`.
|
||||
if (current_node->num_inputs() == 0) {
|
||||
if (!current_node->IsConstant()) {
|
||||
if (outer_context && current_node->IsArg()) {
|
||||
const string& tensor_name =
|
||||
strings::StrCat(current_node->name(), ":", 0);
|
||||
// If we do not already have a constant Tensor for this Arg try to
|
||||
// fetch it from the outer context.
|
||||
if (const_inputs_added.count(tensor_name) == 0) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetArgNodeIndex(
|
||||
current_node, outer_context->num_inputs(), &index));
|
||||
const Tensor* const_tensor = outer_context->input_tensor(index);
|
||||
if (const_tensor) {
|
||||
const_inputs->emplace_back(tensor_name, *const_tensor);
|
||||
const_inputs_added.insert(tensor_name);
|
||||
} else {
|
||||
// Request a constant value for this Arg. If that is statically
|
||||
// computable, shape refiner will re-run the shape inference for
|
||||
// this function with this tensor's value.
|
||||
outer_context->request_input_tensor(index);
|
||||
*is_constant_graph = false;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
} else if (!current_node->IsConstant()) {
|
||||
// Generator node is not a constant, so subgraph is not
|
||||
// constant.
|
||||
*is_constant_graph = false;
|
||||
@ -314,7 +351,8 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
|
||||
Tensor* result, GraphRunner* graph_runner,
|
||||
std::unordered_map<string, Tensor>* cached_values,
|
||||
int64 max_cached_value_size,
|
||||
bool disable_constant_propagation) {
|
||||
bool disable_constant_propagation,
|
||||
InferenceContext* outer_context) {
|
||||
*evaluated = false;
|
||||
const Node* src = tensor.node;
|
||||
|
||||
@ -326,6 +364,22 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
|
||||
}
|
||||
}
|
||||
|
||||
// If the source node is an Arg return its value, if available in the outer
|
||||
// context.
|
||||
if (src->IsArg() && outer_context) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetArgNodeIndex(src, outer_context->num_inputs(), &index));
|
||||
const Tensor* const_tensor = outer_context->input_tensor(index);
|
||||
if (const_tensor) {
|
||||
*evaluated = true;
|
||||
*result = *(outer_context->input_tensor(index));
|
||||
} else {
|
||||
outer_context->request_input_tensor(index);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (disable_constant_propagation) {
|
||||
return Status::OK();
|
||||
}
|
||||
@ -339,7 +393,7 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
|
||||
std::vector<std::pair<string, Tensor>> const_inputs;
|
||||
TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
|
||||
&subgraph, &is_constant_graph,
|
||||
&const_inputs));
|
||||
&const_inputs, outer_context));
|
||||
if (!is_constant_graph) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -53,13 +53,17 @@ class Tensor;
|
||||
// result size to cache.
|
||||
// disable_constant_propagation - if true, only Const node values will be
|
||||
// returned.
|
||||
// outer_context - optional. The InferenceContext for the call node if inside
|
||||
// a nested function. This is useful for doing constant propagation across
|
||||
// Arg nodes.
|
||||
Status EvaluateConstantTensor(
|
||||
OutputTensor tensor, const ShapeRefiner& refiner,
|
||||
const OpRegistryInterface& ops, int32 graph_def_version, bool* evaluated,
|
||||
Tensor* result, GraphRunner* graph_runner = nullptr,
|
||||
std::unordered_map<string, Tensor>* cached_values = nullptr,
|
||||
int64 max_cached_value_size = 1024,
|
||||
bool disable_constant_propagation = false);
|
||||
bool disable_constant_propagation = false,
|
||||
shape_inference::InferenceContext* outer_context = nullptr);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
@ -59,13 +60,15 @@ namespace {
|
||||
constexpr char kArgOp[] = "_Arg";
|
||||
constexpr char kRetvalOp[] = "_Retval";
|
||||
|
||||
} // namespace
|
||||
|
||||
// Runs shape inference for the given node using the given ShapeRefiner.
|
||||
// The node must be a sub-node of a function node and the outer_context is
|
||||
// the inference context of that function node in the outer graph.
|
||||
Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
|
||||
InferenceContext* outer_context) {
|
||||
TF_RETURN_IF_ERROR(refiner->AddNode(node));
|
||||
InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));
|
||||
Status ShapeRefiner::InferShapesForFunctionSubNode(
|
||||
const Node* node, InferenceContext* outer_context) {
|
||||
TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
|
||||
InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));
|
||||
|
||||
if (StringPiece(node->type_string()) == kArgOp) {
|
||||
// Handle special node: function input.
|
||||
@ -126,8 +129,6 @@ Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// TODO(cwhipkey): When an inference context inside function has
|
||||
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
|
||||
// set when input(i) is an _Arg op, then this request should propagate to
|
||||
@ -167,8 +168,8 @@ Status ShapeRefiner::InferShapesForFunction(
|
||||
auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
|
||||
&inference_status](const Node* node) {
|
||||
if (!inference_status.ok()) return;
|
||||
inference_status = InferShapesForFunctionSubNode(
|
||||
node, this, outer_context->get_context());
|
||||
inference_status =
|
||||
InferShapesForFunctionSubNode(node, outer_context->get_context());
|
||||
function_nodes.insert(node);
|
||||
};
|
||||
|
||||
@ -187,6 +188,11 @@ Status ShapeRefiner::InferShapesForFunction(
|
||||
}
|
||||
|
||||
// Adds 'node' without an outer inference context: forwards to AddNodeInternal
// with a null 'outer_context'.
Status ShapeRefiner::AddNode(const Node* node) {
  shape_inference::InferenceContext* const no_outer_context = nullptr;
  return AddNodeInternal(node, no_outer_context);
}
|
||||
|
||||
Status ShapeRefiner::AddNodeInternal(
|
||||
const Node* node, shape_inference::InferenceContext* outer_context) {
|
||||
// Create the inference context for this node with the existing input shapes.
|
||||
std::unique_ptr<InferenceContext> ic(new InferenceContext(
|
||||
graph_def_version_, node->def(), node->op_def(),
|
||||
@ -240,7 +246,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
||||
new ExtendedInferenceContext(std::move(ic), node));
|
||||
|
||||
// Run the shape inference function, and return if there was an error.
|
||||
TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));
|
||||
TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));
|
||||
|
||||
// Store the resulting context object in the map.
|
||||
node_to_context_[node].swap(ec);
|
||||
@ -385,25 +391,25 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
|
||||
return RunShapeFn(node, op_reg_data, node_ext_context);
|
||||
}
|
||||
|
||||
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
|
||||
int dst_idx, bool* evaluated,
|
||||
Tensor* result) {
|
||||
Status ShapeRefiner::EvaluateConstantTensorForEdge(
|
||||
const Node* node, int dst_idx, bool* evaluated, Tensor* result,
|
||||
InferenceContext* outer_context) {
|
||||
*evaluated = false;
|
||||
const Edge* input_edge;
|
||||
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
|
||||
OutputTensor tensor(input_edge->src(), input_edge->src_output());
|
||||
return EvaluateConstantTensor(tensor, *this, *ops_registry_,
|
||||
graph_def_version_, evaluated, result,
|
||||
&graph_runner_, &const_tensor_map_,
|
||||
kMaxTensorSize, disable_constant_propagation_);
|
||||
return EvaluateConstantTensor(
|
||||
tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
|
||||
&graph_runner_, &const_tensor_map_, kMaxTensorSize,
|
||||
disable_constant_propagation_, outer_context);
|
||||
}
|
||||
|
||||
Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
|
||||
int dst_idx, bool* evaluated,
|
||||
int64* result) {
|
||||
Status ShapeRefiner::EvaluateConstantIntScalarEdge(
|
||||
const Node* node, int dst_idx, bool* evaluated, int64* result,
|
||||
shape_inference::InferenceContext* outer_context) {
|
||||
Tensor scalar;
|
||||
TF_RETURN_IF_ERROR(
|
||||
EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar));
|
||||
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
|
||||
&scalar, outer_context));
|
||||
if (*evaluated) {
|
||||
if (scalar.NumElements() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
@ -424,9 +430,9 @@ Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
||||
const Node* node, int dst_idx,
|
||||
ShapeHandle* result) {
|
||||
Status ShapeRefiner::ConstantPartialShape(
|
||||
InferenceContext* target_context, const Node* node, int dst_idx,
|
||||
ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
|
||||
const Edge* input_edge;
|
||||
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
|
||||
|
||||
@ -437,8 +443,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
||||
if (src_context->Value(src_context->Rank(src_shape)) == 0) {
|
||||
Tensor t;
|
||||
bool evaluated = false;
|
||||
TF_RETURN_IF_ERROR(
|
||||
EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
|
||||
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
|
||||
&t, outer_context));
|
||||
if (!evaluated) {
|
||||
return errors::InvalidArgument(
|
||||
"Received a shape scalar with unknown static value. A static value "
|
||||
@ -471,7 +477,9 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
||||
// a float.
|
||||
Tensor t;
|
||||
bool evaluated = false;
|
||||
if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t).ok()) {
|
||||
if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
|
||||
outer_context)
|
||||
.ok()) {
|
||||
if (evaluated &&
|
||||
target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
|
||||
return Status::OK();
|
||||
@ -481,7 +489,7 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
||||
// Then try to infer partial shape from the input to the cast tensor.
|
||||
ShapeHandle pre_cast_shape;
|
||||
if (!ConstantPartialShape(target_context, input_edge->src(), 0,
|
||||
&pre_cast_shape)
|
||||
&pre_cast_shape, outer_context)
|
||||
.ok()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
target_context->MakeShapeFromTensor(nullptr, src_shape, result));
|
||||
@ -510,8 +518,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
||||
for (int i = 0; i < src_context->num_inputs(); ++i) {
|
||||
int64 size;
|
||||
bool evaluated;
|
||||
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i,
|
||||
&evaluated, &size));
|
||||
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
|
||||
input_edge->src(), i, &evaluated, &size, outer_context));
|
||||
if (evaluated) {
|
||||
dims.push_back(size < 0 ? target_context->UnknownDim()
|
||||
: target_context->MakeDim(size));
|
||||
@ -531,7 +539,7 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
||||
if (i == concat_dim) continue;
|
||||
ShapeHandle sub_result;
|
||||
TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
|
||||
i, &sub_result));
|
||||
i, &sub_result, outer_context));
|
||||
if (!target_context->RankKnown(sub_result)) {
|
||||
// Failed to evaluate. Treat the output as completely unknown.
|
||||
// TODO(cwhipkey): we could rely on all inputs being the same rank, so
|
||||
@ -543,8 +551,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
||||
target_context->Concatenate(*result, sub_result, result));
|
||||
}
|
||||
} else if (src_op == "StridedSlice") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PartialStridedSliceShape(input_edge->src(), src_context, result));
|
||||
TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
|
||||
result, outer_context));
|
||||
} else if (src_op == "VariableShape") {
|
||||
auto* handle_data = src_context->input_handle_shapes_and_types(0);
|
||||
if (handle_data != nullptr && !handle_data->empty()) {
|
||||
@ -555,17 +563,17 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
||||
} else {
|
||||
Tensor t;
|
||||
bool evaluated = false;
|
||||
TF_RETURN_IF_ERROR(
|
||||
EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
|
||||
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
|
||||
&t, outer_context));
|
||||
TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
|
||||
evaluated ? &t : nullptr, src_shape, result));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
||||
InferenceContext* ctx,
|
||||
ShapeHandle* result) {
|
||||
Status ShapeRefiner::PartialStridedSliceShape(
|
||||
Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
|
||||
shape_inference::InferenceContext* outer_context) {
|
||||
// Only attempt to evaluate if begin/end/strides all are scalars.
|
||||
for (int i = 1; i <= 3; ++i) {
|
||||
ShapeHandle input_shape = ctx->input(i);
|
||||
@ -600,8 +608,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
||||
if (begin_mask == 1) {
|
||||
begin = 0;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin));
|
||||
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
|
||||
&begin, outer_context));
|
||||
if (!evaluated) {
|
||||
*result = ctx->UnknownShape();
|
||||
return Status::OK();
|
||||
@ -612,8 +620,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
||||
if (end_mask == 1) {
|
||||
end = std::numeric_limits<int64>::max();
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end));
|
||||
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
|
||||
&end, outer_context));
|
||||
if (!evaluated) {
|
||||
*result = ctx->UnknownShape();
|
||||
return Status::OK();
|
||||
@ -621,8 +629,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
||||
}
|
||||
|
||||
int64 stride;
|
||||
TF_RETURN_IF_ERROR(
|
||||
EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride));
|
||||
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
|
||||
&stride, outer_context));
|
||||
if (!evaluated) {
|
||||
*result = ctx->UnknownShape();
|
||||
return Status::OK();
|
||||
@ -630,14 +638,16 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
||||
|
||||
// Apply stride to input interpreted as a partial shape.
|
||||
ShapeHandle input;
|
||||
TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
|
||||
TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShapeRefiner::RunShapeFn(const Node* node,
|
||||
const OpRegistrationData* op_reg_data,
|
||||
ExtendedInferenceContext* ec) {
|
||||
ExtendedInferenceContext* ec,
|
||||
InferenceContext* outer_context) {
|
||||
// This will be filled in with real data in a second pass.
|
||||
std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
|
||||
std::vector<Tensor> real_tensors(node->num_inputs());
|
||||
@ -719,8 +729,8 @@ Status ShapeRefiner::RunShapeFn(const Node* node,
|
||||
|
||||
Tensor result;
|
||||
bool evaluated = false;
|
||||
TF_RETURN_IF_ERROR(
|
||||
EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
|
||||
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
|
||||
node, i, &evaluated, &result, outer_context));
|
||||
if (evaluated) {
|
||||
real_tensors[i] = result;
|
||||
input_tensors[i] = &real_tensors[i];
|
||||
@ -736,7 +746,7 @@ Status ShapeRefiner::RunShapeFn(const Node* node,
|
||||
input_tensors_as_shapes.resize(i + 1);
|
||||
}
|
||||
ShapeHandle s;
|
||||
TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
|
||||
TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s, outer_context));
|
||||
input_tensors_as_shapes[i] = s;
|
||||
rerun_shape_fn = true;
|
||||
}
|
||||
|
@ -184,17 +184,56 @@ class ShapeRefiner {
|
||||
AttrSlice attributes,
|
||||
ExtendedInferenceContext* outer_context);
|
||||
|
||||
// Performs shape inference for a node inside a function.
|
||||
//
|
||||
// 'outer_context' is the 'InferenceContext' for the function's call op.
|
||||
Status InferShapesForFunctionSubNode(
|
||||
const Node* node, shape_inference::InferenceContext* outer_context);
|
||||
|
||||
// Performs validation of 'node' and runs 'node's shape function,
|
||||
// storing its shape outputs.
|
||||
//
|
||||
// All inputs of 'node' must be added to ShapeRefiner prior to
|
||||
// adding 'node'.
|
||||
//
|
||||
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||
// by requesting the constant value of the incoming tensor from the
|
||||
// 'outer_context'.
|
||||
//
|
||||
// Returns an error if:
|
||||
// - the shape function for 'node' was not registered.
|
||||
// - 'node' was added before its inputs.
|
||||
// - The shape inference function returns an error.
|
||||
Status AddNodeInternal(const Node* node,
|
||||
shape_inference::InferenceContext* outer_context);
|
||||
|
||||
// Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
|
||||
// value can be evaluated, 'evaluated' is set to true and the value returned
|
||||
// in 'result'. Otherwise 'evaluated' is set to false.
|
||||
Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
|
||||
bool* evaluated, Tensor* result);
|
||||
//
|
||||
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||
// by requesting the constant value of the incoming tensor from the
|
||||
// 'outer_context'.
|
||||
Status EvaluateConstantTensorForEdge(
|
||||
const Node* node, int dst_idx, bool* evaluated, Tensor* result,
|
||||
shape_inference::InferenceContext* outer_context);
|
||||
|
||||
// Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
|
||||
// tensors. The caller is responsible for checking that the specified edge is
|
||||
// scalar and int32 or int64.
|
||||
Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx,
|
||||
bool* evaluated, int64* result);
|
||||
//
|
||||
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||
// by requesting the constant value of the incoming tensor from the
|
||||
// 'outer_context'.
|
||||
Status EvaluateConstantIntScalarEdge(
|
||||
const Node* node, int dst_idx, bool* evaluated, int64* result,
|
||||
shape_inference::InferenceContext* outer_context);
|
||||
|
||||
// This function tries to materialize as much information about the 'node''s
|
||||
// dst_idx input as a statically computable shape, and the result may be
|
||||
@ -217,17 +256,39 @@ class ShapeRefiner {
|
||||
//
|
||||
// <target_context> is used when creating new DimensionHandle and ShapeHandle
|
||||
// objects.
|
||||
//
|
||||
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||
// by requesting the constant value of the incoming tensor from the
|
||||
// 'outer_context'.
|
||||
Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
|
||||
const Node* node, int dst_idx,
|
||||
shape_inference::ShapeHandle* result);
|
||||
shape_inference::ShapeHandle* result,
|
||||
shape_inference::InferenceContext* outer_context);
|
||||
|
||||
// Implementation of ConstantPartialShape for StridedSlice nodes.
|
||||
Status PartialStridedSliceShape(Node* slice_node,
|
||||
shape_inference::InferenceContext* ctx,
|
||||
shape_inference::ShapeHandle* result);
|
||||
//
|
||||
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||
// by requesting the constant value of the incoming tensor from the
|
||||
// 'outer_context'.
|
||||
Status PartialStridedSliceShape(
|
||||
Node* slice_node, shape_inference::InferenceContext* ctx,
|
||||
shape_inference::ShapeHandle* result,
|
||||
shape_inference::InferenceContext* outer_context);
|
||||
|
||||
// Runs the shape function registered for the node's op type.
|
||||
//
|
||||
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||
// by requesting the constant value of the incoming tensor from the
|
||||
// 'outer_context'.
|
||||
Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
|
||||
ExtendedInferenceContext* ec);
|
||||
ExtendedInferenceContext* ec,
|
||||
shape_inference::InferenceContext* outer_context = nullptr);
|
||||
|
||||
int32 graph_def_version_;
|
||||
const OpRegistryInterface* const ops_registry_;
|
||||
|
@ -719,7 +719,7 @@ Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
|
||||
ShapeHandle input_shape;
|
||||
TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
|
||||
|
||||
requested_input_tensor_as_partial_shape_[input_idx] = true;
|
||||
request_input_tensor_as_partial_shape(input_idx);
|
||||
const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
|
||||
if (input_idx < input_tensors_as_shapes_size &&
|
||||
input_tensors_as_shapes_[input_idx].IsSet() &&
|
||||
@ -738,7 +738,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
|
||||
ShapeHandle input_shape;
|
||||
TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
|
||||
|
||||
requested_input_tensor_as_partial_shape_[input_idx] = true;
|
||||
request_input_tensor_as_partial_shape(input_idx);
|
||||
const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
|
||||
if (input_idx < input_tensors_as_shapes_size &&
|
||||
input_tensors_as_shapes_[input_idx].IsSet() &&
|
||||
|
@ -268,15 +268,31 @@ class InferenceContext {
|
||||
// not available at the time of shape inference.
|
||||
const Tensor* input_tensor(int idx) {
|
||||
// Mark that this idx was requested.
|
||||
requested_input_tensor_[idx] = true;
|
||||
request_input_tensor(idx);
|
||||
return input_tensors_[idx];
|
||||
}
|
||||
|
||||
// Notifies the shape refiner that the value of the tensor at index <idx>
|
||||
// is needed. The shape refiner tries to statically compute this tensor,
|
||||
// and if successful re-runs the shape function with this tensor available
|
||||
// in the call to 'input_tensor(idx)'.
|
||||
void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; }
|
||||
|
||||
// Returns true iff input_tensor(idx) was called by the shape function.
|
||||
bool requested_input_tensor(int idx) const {
|
||||
return requested_input_tensor_[idx];
|
||||
}
|
||||
|
||||
// Notifies the shape refiner that the value of the tensor at index <idx>
|
||||
// as a partial shape is needed. The shape refiner tries to statically compute
|
||||
// this, and if successful re-runs the shape function with the
|
||||
// computed PartialTensorShape available in the call to
|
||||
// 'MakeShapeFromShapeTensor(idx, handle)' or
|
||||
// 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'.
|
||||
void request_input_tensor_as_partial_shape(int idx) {
|
||||
requested_input_tensor_as_partial_shape_[idx] = true;
|
||||
}
|
||||
|
||||
// Returns true if MakeShapeFromShapeTensor was called but the constant
|
||||
// input_tensor was not present.
|
||||
bool requested_input_tensor_as_partial_shape(int idx) const {
|
||||
|
@ -4297,6 +4297,92 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
with self.assertRaisesRegex(TypeError, 'missing required arguments: y'):
|
||||
foo.add(2) # pylint: disable=no-value-for-parameter
|
||||
|
||||
def testShapeInferencePropagateConstNestedStack(self):
  """A constant fed to a nested function yields a static shape via `stack`."""
  inner_signature = [
      tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
  ]
  outer_signature = [tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)]

  @def_function.function(input_signature=inner_signature)
  def f(x, s):
    leading_dim = array_ops.shape(x)[0]
    target_shape = array_ops.stack([leading_dim, s], axis=0)
    return array_ops.ones(shape=target_shape, dtype=dtypes.int32)

  @def_function.function(input_signature=outer_signature)
  def g(x):
    y = f(x, s=5)
    # Plain Python assert on the static shape of the nested call's output.
    assert y.shape.as_list() == [3, 5], y.shape.as_list()
    return y

  actual = g(array_ops.zeros([3, 6], dtype=dtypes.int32))
  self.assertAllEqual(actual, array_ops.ones([3, 5]))
|
||||
|
||||
def testShapeInferencePropagateConstNestedUnstackStack(self):
  """Same as the `stack` case, but the leading dim comes from `unstack`."""
  inner_signature = [
      tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
  ]
  outer_signature = [tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)]

  @def_function.function(input_signature=inner_signature)
  def f(x, s):
    dims = array_ops.unstack(array_ops.shape(x), axis=0)
    s0, _ = dims
    target_shape = array_ops.stack([s0, s], axis=0)
    return array_ops.ones(shape=target_shape, dtype=dtypes.int32)

  @def_function.function(input_signature=outer_signature)
  def g(x):
    y = f(x, s=5)
    # Plain Python assert on the static shape of the nested call's output.
    assert y.shape.as_list() == [3, 5], y.shape.as_list()
    return y

  actual = g(array_ops.zeros([3, 6], dtype=dtypes.int32))
  self.assertAllEqual(actual, array_ops.ones([3, 5]))
|
||||
|
||||
def testShapeInferencePropagateConstNestedConcat(self):
  """Three scalar constants fed to a nested function build a static shape."""
  scalar_signature = [
      tensor_spec.TensorSpec((), dtype=dtypes.int32) for _ in range(3)
  ]

  @def_function.function(input_signature=scalar_signature)
  def f(d1, d2, d3):
    target_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
    return array_ops.ones(shape=target_shape, dtype=dtypes.int32)

  @def_function.function()
  def g():
    y = f(1, 2, 3)
    # Plain Python assert on the static shape of the nested call's output.
    assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
    return y

  actual = g()
  self.assertAllEqual(actual, array_ops.ones([1, 2, 3]))
|
||||
|
||||
def testShapeInferencePropagateConstDoubleNested(self):
  """Like the concat case, but `f` is wrapped in a second `function` call."""
  scalar_signature = [
      tensor_spec.TensorSpec((), dtype=dtypes.int32) for _ in range(3)
  ]

  @def_function.function(input_signature=scalar_signature)
  def f(d1, d2, d3):
    target_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
    return array_ops.ones(shape=target_shape, dtype=dtypes.int32)

  @def_function.function()
  def g():
    rewrapped_f = def_function.function(f)
    y = rewrapped_f(1, 2, 3)
    # Plain Python assert on the static shape of the nested call's output.
    assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
    return y

  actual = g()
  self.assertAllEqual(actual, array_ops.ones([1, 2, 3]))
|
||||
|
||||
|
||||
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
@ -2007,7 +2007,8 @@ class CentralCropTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(y.op.name.startswith("central_crop"))
|
||||
|
||||
|
||||
class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
|
||||
class PadToBoundingBoxTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def _PadToBoundingBox(self, x, offset_height, offset_width, target_height,
|
||||
target_width, use_tensor_inputs):
|
||||
@ -2172,7 +2173,10 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
|
||||
"inner 3 dims of \\'image.shape\\' must be > 0",
|
||||
use_tensor_inputs_options=[True])
|
||||
|
||||
def testBadParams(self):
|
||||
def testBadParamsScalarInputs(self):
|
||||
# In this test, inputs do not get converted to tensors before calling the
|
||||
# tf.function. The error message here is raised in python
|
||||
# since the python function has direct access to the scalars.
|
||||
x_shape = [3, 3, 1]
|
||||
x = np.zeros(x_shape)
|
||||
|
||||
@ -2187,9 +2191,49 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
|
||||
"height must be <= target - offset"),
|
||||
(0, 2, 4, 4,
|
||||
"width must be <= target - offset"))
|
||||
|
||||
for config_item in test_config:
|
||||
self._assertRaises(x, x_shape, *config_item)
|
||||
self._assertRaises(
|
||||
x, x_shape, *config_item, use_tensor_inputs_options=[False])
|
||||
|
||||
def testBadParamsTensorInputsEager(self):
  # Inputs are converted to EagerTensors before the tf.function is called, so
  # the Python function can read their values directly and the error message
  # is raised in Python.
  with context.eager_mode():
    image_shape = [3, 3, 1]
    image = np.zeros(image_shape)

    # Columns:
    # offset_height, offset_width, target_height, target_width, err_msg
    bad_configs = (
        (-1, 0, 4, 4, "offset_height must be >= 0"),
        (0, -1, 4, 4, "offset_width must be >= 0"),
        (2, 0, 4, 4, "height must be <= target - offset"),
        (0, 2, 4, 4, "width must be <= target - offset"),
    )

    for (offset_height, offset_width, target_height, target_width,
         err_msg) in bad_configs:
      self._assertRaises(
          image,
          image_shape,
          offset_height,
          offset_width,
          target_height,
          target_width,
          err_msg,
          use_tensor_inputs_options=[True])
|
||||
|
||||
@parameterized.named_parameters([
    ("OffsetHeight", (-1, 0, 4, 4)),
    ("OffsetWidth", (0, -1, 4, 4)),
    ("Height", (2, 0, 4, 4)),
    ("Width", (0, 2, 4, 4)),
])
def testBadParamsTensorInputsGraph(self, config):
  # Inputs are converted to tensors before the tf.function is called, so the
  # error message is raised during shape inference.
  expected_error = "Paddings must be non-negative"
  with context.graph_mode():
    image_shape = [3, 3, 1]
    image = np.zeros(image_shape)
    self._assertRaises(
        image,
        image_shape,
        *config,
        expected_error,
        use_tensor_inputs_options=[True])
|
||||
|
||||
def testNameScope(self):
|
||||
# Testing name scope requires a graph.
|
||||
|
Loading…
Reference in New Issue
Block a user