Support constant folding across Arg nodes during shape inference on nested function calls.

PiperOrigin-RevId: 337992497
Change-Id: If480c605fa34def3586a62105c5517c02c2f73e6
Saurabh Saxena 2020-10-19 21:22:06 -07:00 committed by TensorFlower Gardener
parent 969055f71a
commit f41a9a9335
9 changed files with 350 additions and 74 deletions
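
In user-facing terms: when one tf.function calls another and passes a constant argument, shape inference inside the callee can now fold that constant across the _Arg node it arrives through. A minimal sketch of the effect, adapted from testShapeInferencePropagateConstNestedStack in the test diff below (spelled with public tf.* names here; the test itself uses TensorFlow's internal imports):

import tensorflow as tf

@tf.function(input_signature=[
    tf.TensorSpec((None, None), dtype=tf.int32),
    tf.TensorSpec((), dtype=tf.int32),
])
def f(x, s):
  # `s` enters f through an _Arg node; with this change shape inference
  # can fold the constant 5 passed by the caller across that node.
  new_shape = tf.stack([tf.shape(x)[0], s], axis=0)
  return tf.ones(shape=new_shape, dtype=tf.int32)

@tf.function(input_signature=[tf.TensorSpec((3, 6), dtype=tf.int32)])
def g(x):
  y = f(x, s=5)
  # Before this change the static shape here was [3, None].
  assert y.shape.as_list() == [3, 5], y.shape.as_list()
  return y

g(tf.zeros([3, 6], dtype=tf.int32))  # traces g (and f); the assert passes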


@@ -162,7 +162,8 @@
stateful ops.
* Added `tf.config.experimental.get_memory_usage` to return total memory
usage of the device.
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
* Improved shape inference of nested function calls by supporting constant folding across Arg nodes, which makes more static values available to shape inference functions.
* `tf.data`:
* tf.data service:
* Added new `tf.data.experimental.service.register_dataset` and


@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
@@ -123,6 +124,17 @@ bool HasCpuKernel(const Node& node) {
.ok();
}
Status GetArgNodeIndex(const Node* node, int num_function_inputs, int* index) {
DCHECK(node->IsArg());
TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", index));
if (*index < 0 || num_function_inputs <= *index) {
return errors::Internal(
"Function instantiation included invalid input index: ", *index,
" not in [0, ", num_function_inputs, ").");
}
return Status::OK();
}
// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts into 'out_graph'. If statically computable, 'is_constant_graph'
// will be set to true.
@@ -130,7 +142,8 @@ Status ExtractConstantSubgraph(
const Node& target_node, const ShapeRefiner& refiner,
const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
bool* is_constant_graph,
std::vector<std::pair<string, Tensor>>* const_inputs) {
std::vector<std::pair<string, Tensor>>* const_inputs,
InferenceContext* outer_context) {
*is_constant_graph = false;
std::unordered_set<string> const_inputs_added;
@@ -187,8 +200,9 @@ Status ExtractConstantSubgraph(
edges_to_visit.pop_front();
Node* current_node = current_edge->src();
// If the node is stateful, assume the graph is not constant.
if (current_node->op_def().is_stateful()) {
// If the node is stateful, assume the graph is not constant unless it is
// an Arg node which is handled later on.
if (!current_node->IsArg() && current_node->op_def().is_stateful()) {
*is_constant_graph = false;
return Status::OK();
}
@@ -223,9 +237,32 @@
}
// If there is nothing more to recurse down, see if
// the generator node is a constant.
// the generator node is a constant or an Arg node whose value is available
// in the `outer_context`.
if (current_node->num_inputs() == 0) {
if (!current_node->IsConstant()) {
if (outer_context && current_node->IsArg()) {
const string& tensor_name =
strings::StrCat(current_node->name(), ":", 0);
// If we do not already have a constant Tensor for this Arg try to
// fetch it from the outer context.
if (const_inputs_added.count(tensor_name) == 0) {
int index;
TF_RETURN_IF_ERROR(GetArgNodeIndex(
current_node, outer_context->num_inputs(), &index));
const Tensor* const_tensor = outer_context->input_tensor(index);
if (const_tensor) {
const_inputs->emplace_back(tensor_name, *const_tensor);
const_inputs_added.insert(tensor_name);
} else {
// Request a constant value for this Arg. If that is statically
// computable, shape refiner will re-run the shape inference for
// this function with this tensor's value.
outer_context->request_input_tensor(index);
*is_constant_graph = false;
return Status::OK();
}
}
} else if (!current_node->IsConstant()) {
// Generator node is not a constant, so subgraph is not
// constant.
*is_constant_graph = false;
@@ -314,7 +351,8 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
Tensor* result, GraphRunner* graph_runner,
std::unordered_map<string, Tensor>* cached_values,
int64 max_cached_value_size,
bool disable_constant_propagation) {
bool disable_constant_propagation,
InferenceContext* outer_context) {
*evaluated = false;
const Node* src = tensor.node;
@@ -326,6 +364,22 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
}
}
// If the source node is an Arg return its value, if available in the outer
// context.
if (src->IsArg() && outer_context) {
int index;
TF_RETURN_IF_ERROR(
GetArgNodeIndex(src, outer_context->num_inputs(), &index));
const Tensor* const_tensor = outer_context->input_tensor(index);
if (const_tensor) {
*evaluated = true;
*result = *const_tensor;
} else {
outer_context->request_input_tensor(index);
}
return Status::OK();
}
if (disable_constant_propagation) {
return Status::OK();
}
@@ -339,7 +393,7 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
std::vector<std::pair<string, Tensor>> const_inputs;
TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
&subgraph, &is_constant_graph,
&const_inputs));
&const_inputs, outer_context));
if (!is_constant_graph) {
return Status::OK();
}
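
For orientation, a toy Python model of the lookup-or-request flow that the ExtractConstantSubgraph and EvaluateConstantTensor changes above implement for Arg nodes. The names here are hypothetical stand-ins, not TensorFlow APIs:

class OuterContext:
  """Toy stand-in for the call op's InferenceContext."""

  def __init__(self, input_tensors):
    # index -> constant value, or None when that input is not a constant
    self.input_tensors = input_tensors
    self.requested = set()

  def input_tensor(self, index):
    return self.input_tensors.get(index)

  def request_input_tensor(self, index):
    self.requested.add(index)

def evaluate_arg(arg_index, outer_context):
  # Mirrors the new Arg branch: use the caller's value when it is already
  # constant; otherwise record a request so the shape refiner can try to
  # compute it statically and re-run shape inference for the function.
  value = outer_context.input_tensor(arg_index)
  if value is not None:
    return True, value              # evaluated
  outer_context.request_input_tensor(arg_index)
  return False, None                # not constant (yet)

ctx = OuterContext({0: 5, 1: None})
print(evaluate_arg(0, ctx))  # (True, 5)
print(evaluate_arg(1, ctx))  # (False, None); index 1 is now in ctx.requested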


@@ -53,13 +53,17 @@ class Tensor;
// result size to cache.
// disable_constant_propagation - if true, only Const node values will be
// returned.
// outer_context - optional. The InferenceContext for the call node if inside
// a nested function. This is useful for doing constant propagation across
// Arg nodes.
Status EvaluateConstantTensor(
OutputTensor tensor, const ShapeRefiner& refiner,
const OpRegistryInterface& ops, int32 graph_def_version, bool* evaluated,
Tensor* result, GraphRunner* graph_runner = nullptr,
std::unordered_map<string, Tensor>* cached_values = nullptr,
int64 max_cached_value_size = 1024,
bool disable_constant_propagation = false);
bool disable_constant_propagation = false,
shape_inference::InferenceContext* outer_context = nullptr);
} // namespace tensorflow


@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -59,13 +60,15 @@ namespace {
constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";
} // namespace
// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
InferenceContext* outer_context) {
TF_RETURN_IF_ERROR(refiner->AddNode(node));
InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));
Status ShapeRefiner::InferShapesForFunctionSubNode(
const Node* node, InferenceContext* outer_context) {
TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));
if (StringPiece(node->type_string()) == kArgOp) {
// Handle special node: function input.
@@ -126,8 +129,6 @@ Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
return Status::OK();
}
} // namespace
// TODO(cwhipkey): When an inference context inside function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to
@@ -167,8 +168,8 @@ Status ShapeRefiner::InferShapesForFunction(
auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
&inference_status](const Node* node) {
if (!inference_status.ok()) return;
inference_status = InferShapesForFunctionSubNode(
node, this, outer_context->get_context());
inference_status =
InferShapesForFunctionSubNode(node, outer_context->get_context());
function_nodes.insert(node);
};
@@ -187,6 +188,11 @@ Status ShapeRefiner::InferShapesForFunction(
}
Status ShapeRefiner::AddNode(const Node* node) {
return AddNodeInternal(node, /*outer_context=*/nullptr);
}
Status ShapeRefiner::AddNodeInternal(
const Node* node, shape_inference::InferenceContext* outer_context) {
// Create the inference context for this node with the existing input shapes.
std::unique_ptr<InferenceContext> ic(new InferenceContext(
graph_def_version_, node->def(), node->op_def(),
@@ -240,7 +246,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
new ExtendedInferenceContext(std::move(ic), node));
// Run the shape inference function, and return if there was an error.
TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));
TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));
// Store the resulting context object in the map.
node_to_context_[node].swap(ec);
@@ -385,25 +391,25 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
return RunShapeFn(node, op_reg_data, node_ext_context);
}
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
int dst_idx, bool* evaluated,
Tensor* result) {
Status ShapeRefiner::EvaluateConstantTensorForEdge(
const Node* node, int dst_idx, bool* evaluated, Tensor* result,
InferenceContext* outer_context) {
*evaluated = false;
const Edge* input_edge;
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
OutputTensor tensor(input_edge->src(), input_edge->src_output());
return EvaluateConstantTensor(tensor, *this, *ops_registry_,
graph_def_version_, evaluated, result,
&graph_runner_, &const_tensor_map_,
kMaxTensorSize, disable_constant_propagation_);
return EvaluateConstantTensor(
tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
&graph_runner_, &const_tensor_map_, kMaxTensorSize,
disable_constant_propagation_, outer_context);
}
Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
int dst_idx, bool* evaluated,
int64* result) {
Status ShapeRefiner::EvaluateConstantIntScalarEdge(
const Node* node, int dst_idx, bool* evaluated, int64* result,
shape_inference::InferenceContext* outer_context) {
Tensor scalar;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar));
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
&scalar, outer_context));
if (*evaluated) {
if (scalar.NumElements() != 1) {
return errors::InvalidArgument(
@@ -424,9 +430,9 @@ Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
return Status::OK();
}
Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
const Node* node, int dst_idx,
ShapeHandle* result) {
Status ShapeRefiner::ConstantPartialShape(
InferenceContext* target_context, const Node* node, int dst_idx,
ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
const Edge* input_edge;
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
@@ -437,8 +443,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
if (src_context->Value(src_context->Rank(src_shape)) == 0) {
Tensor t;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
&t, outer_context));
if (!evaluated) {
return errors::InvalidArgument(
"Received a shape scalar with unknown static value. A static value "
@@ -471,7 +477,9 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
// a float.
Tensor t;
bool evaluated = false;
if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t).ok()) {
if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
outer_context)
.ok()) {
if (evaluated &&
target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
return Status::OK();
@@ -481,7 +489,7 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
// Then try to infer partial shape from the input to the cast tensor.
ShapeHandle pre_cast_shape;
if (!ConstantPartialShape(target_context, input_edge->src(), 0,
&pre_cast_shape)
&pre_cast_shape, outer_context)
.ok()) {
TF_RETURN_IF_ERROR(
target_context->MakeShapeFromTensor(nullptr, src_shape, result));
@@ -510,8 +518,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
for (int i = 0; i < src_context->num_inputs(); ++i) {
int64 size;
bool evaluated;
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i,
&evaluated, &size));
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
input_edge->src(), i, &evaluated, &size, outer_context));
if (evaluated) {
dims.push_back(size < 0 ? target_context->UnknownDim()
: target_context->MakeDim(size));
@@ -531,7 +539,7 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
if (i == concat_dim) continue;
ShapeHandle sub_result;
TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
i, &sub_result));
i, &sub_result, outer_context));
if (!target_context->RankKnown(sub_result)) {
// Failed to evaluate. Treat the output as completely unknown.
// TODO(cwhipkey): we could rely on all inputs being the same rank, so
@@ -543,8 +551,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
target_context->Concatenate(*result, sub_result, result));
}
} else if (src_op == "StridedSlice") {
TF_RETURN_IF_ERROR(
PartialStridedSliceShape(input_edge->src(), src_context, result));
TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
result, outer_context));
} else if (src_op == "VariableShape") {
auto* handle_data = src_context->input_handle_shapes_and_types(0);
if (handle_data != nullptr && !handle_data->empty()) {
@@ -555,17 +563,17 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
} else {
Tensor t;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
&t, outer_context));
TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
evaluated ? &t : nullptr, src_shape, result));
}
return Status::OK();
}
Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
InferenceContext* ctx,
ShapeHandle* result) {
Status ShapeRefiner::PartialStridedSliceShape(
Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
shape_inference::InferenceContext* outer_context) {
// Only attempt to evaluate if begin/end/strides all are scalars.
for (int i = 1; i <= 3; ++i) {
ShapeHandle input_shape = ctx->input(i);
@@ -600,8 +608,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
if (begin_mask == 1) {
begin = 0;
} else {
TF_RETURN_IF_ERROR(
EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin));
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
&begin, outer_context));
if (!evaluated) {
*result = ctx->UnknownShape();
return Status::OK();
@@ -612,8 +620,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
if (end_mask == 1) {
end = std::numeric_limits<int64>::max();
} else {
TF_RETURN_IF_ERROR(
EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end));
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
&end, outer_context));
if (!evaluated) {
*result = ctx->UnknownShape();
return Status::OK();
@@ -621,8 +629,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
}
int64 stride;
TF_RETURN_IF_ERROR(
EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride));
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
&stride, outer_context));
if (!evaluated) {
*result = ctx->UnknownShape();
return Status::OK();
@@ -630,14 +638,16 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
// Apply stride to input interpreted as a partial shape.
ShapeHandle input;
TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input));
TF_RETURN_IF_ERROR(
ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
return Status::OK();
}
Status ShapeRefiner::RunShapeFn(const Node* node,
const OpRegistrationData* op_reg_data,
ExtendedInferenceContext* ec) {
ExtendedInferenceContext* ec,
InferenceContext* outer_context) {
// This will be filled in with real data in a second pass.
std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
std::vector<Tensor> real_tensors(node->num_inputs());
@@ -719,8 +729,8 @@ Status ShapeRefiner::RunShapeFn(const Node* node,
Tensor result;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
node, i, &evaluated, &result, outer_context));
if (evaluated) {
real_tensors[i] = result;
input_tensors[i] = &real_tensors[i];
@@ -736,7 +746,7 @@ Status ShapeRefiner::RunShapeFn(const Node* node,
input_tensors_as_shapes.resize(i + 1);
}
ShapeHandle s;
TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s, outer_context));
input_tensors_as_shapes[i] = s;
rerun_shape_fn = true;
}
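
A note on the Pack branch earlier in this file: a Pack (tf.stack) of scalars becomes a partial shape with one dimension per element, and any element whose constant value cannot be obtained, or is negative, becomes an unknown dimension. A toy rendering under those assumptions (hypothetical helper, not TF code):

def partial_shape_from_pack(elements, outer_constants):
  # One dimension per packed scalar; unavailable or negative values map to
  # None, the stand-in for an unknown dimension.
  dims = []
  for elem in elements:
    value = outer_constants.get(elem)  # e.g. an Arg's constant from the caller
    dims.append(value if value is not None and value >= 0 else None)
  return dims

# Shape [d0, s] where d0 is unknown but s=5 comes from the outer call:
print(partial_shape_from_pack(['d0', 's'], {'s': 5}))  # [None, 5]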


@@ -184,17 +184,56 @@ class ShapeRefiner {
AttrSlice attributes,
ExtendedInferenceContext* outer_context);
// Performs shape inference for a node inside a function.
//
// 'outer_context' is the 'InferenceContext' for the function's call op.
Status InferShapesForFunctionSubNode(
const Node* node, shape_inference::InferenceContext* outer_context);
// Performs validation of 'node' and runs 'node's shape function,
// storing its shape outputs.
//
// All inputs of 'node' must be added to ShapeRefiner prior to
// adding 'node'.
//
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
// the call op of the function can be passed as 'outer_context' (pass nullptr
// otherwise). This gets used to perform constant propagation across Arg nodes
// by requesting the constant value of the incoming tensor from the
// 'outer_context'.
//
// Returns an error if:
// - the shape function for 'node' was not registered.
// - 'node' was added before its inputs.
// - The shape inference function returns an error.
Status AddNodeInternal(const Node* node,
shape_inference::InferenceContext* outer_context);
// Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
// value can be evaluated, 'evaluated' is set to true and the value returned
// in 'result'. Otherwise 'evaluated' is set to false.
Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
bool* evaluated, Tensor* result);
//
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
// the call op of the function can be passed as 'outer_context' (pass nullptr
// otherwise). This gets used to perform constant propagation across Arg nodes
// by requesting the constant value of the incoming tensor from the
// 'outer_context'.
Status EvaluateConstantTensorForEdge(
const Node* node, int dst_idx, bool* evaluated, Tensor* result,
shape_inference::InferenceContext* outer_context);
// Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
// tensors. The caller is responsible for checking that the specified edge is
// scalar and int32 or int64.
Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx,
bool* evaluated, int64* result);
//
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
// the call op of the function can be passed as 'outer_context' (pass nullptr
// otherwise). This gets used to perform constant propagation across Arg nodes
// by requesting the constant value of the incoming tensor from the
// 'outer_context'.
Status EvaluateConstantIntScalarEdge(
const Node* node, int dst_idx, bool* evaluated, int64* result,
shape_inference::InferenceContext* outer_context);
// This function tries to materialize as much information about the 'node''s
// dst_idx input as a statically computable shape, and the result may be
@@ -217,17 +256,39 @@
//
// <target_context> is used when creating new DimensionHandle and ShapeHandle
// objects.
//
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
// the call op of the function can be passed as 'outer_context' (pass nullptr
// otherwise). This gets used to perform constant propagation across Arg nodes
// by requesting the constant value of the incoming tensor from the
// 'outer_context'.
Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);
shape_inference::ShapeHandle* result,
shape_inference::InferenceContext* outer_context);
// Implementation of ConstantPartialShape for StridedSlice nodes.
Status PartialStridedSliceShape(Node* slice_node,
shape_inference::InferenceContext* ctx,
shape_inference::ShapeHandle* result);
//
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
// the call op of the function can be passed as 'outer_context' (pass nullptr
// otherwise). This gets used to perform constant propagation across Arg nodes
// by requesting the constant value of the incoming tensor from the
// 'outer_context'.
Status PartialStridedSliceShape(
Node* slice_node, shape_inference::InferenceContext* ctx,
shape_inference::ShapeHandle* result,
shape_inference::InferenceContext* outer_context);
// Runs the shape function registered for the node's op type.
//
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
// the call op of the function can be passed as 'outer_context' (pass nullptr
// otherwise). This gets used to perform constant propagation across Arg nodes
// by requesting the constant value of the incoming tensor from the
// 'outer_context'.
Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
ExtendedInferenceContext* ec);
ExtendedInferenceContext* ec,
shape_inference::InferenceContext* outer_context = nullptr);
int32 graph_def_version_;
const OpRegistryInterface* const ops_registry_;


@@ -719,7 +719,7 @@ Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
ShapeHandle input_shape;
TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
requested_input_tensor_as_partial_shape_[input_idx] = true;
request_input_tensor_as_partial_shape(input_idx);
const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
if (input_idx < input_tensors_as_shapes_size &&
input_tensors_as_shapes_[input_idx].IsSet() &&
@@ -738,7 +738,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
ShapeHandle input_shape;
TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
requested_input_tensor_as_partial_shape_[input_idx] = true;
request_input_tensor_as_partial_shape(input_idx);
const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
if (input_idx < input_tensors_as_shapes_size &&
input_tensors_as_shapes_[input_idx].IsSet() &&


@@ -268,15 +268,31 @@ class InferenceContext {
// not available at the time of shape inference.
const Tensor* input_tensor(int idx) {
// Mark that this idx was requested.
requested_input_tensor_[idx] = true;
request_input_tensor(idx);
return input_tensors_[idx];
}
// Notifies the shape refiner that the value of the tensor at index <idx>
// is needed. The shape refiner tries to statically compute this tensor,
// and if successful re-runs the shape function with this tensor available
// in the call to 'input_tensor(idx)'.
void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; }
// Returns true iff input_tensor(idx) was called by the shape function.
bool requested_input_tensor(int idx) const {
return requested_input_tensor_[idx];
}
// Notifies the shape refiner that the value of the tensor at index <idx>
// as a partial shape is needed. The shape refiner tries to statically compute
// this, and if successful re-runs the shape function with the
// computed PartialTensorShape available in the call to
// 'MakeShapeFromShapeTensor(idx, handle)' or
// 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'.
void request_input_tensor_as_partial_shape(int idx) {
requested_input_tensor_as_partial_shape_[idx] = true;
}
// Returns true if MakeShapeFromInputTensor was called but the constant
// input_tensor was not present.
bool requested_input_tensor_as_partial_shape(int idx) const {
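
The two request_input_tensor* methods above make the protocol between shape functions and the refiner explicit: a shape function records which input values it needs, the refiner tries to compute them statically, and on success re-runs the shape function with the values filled in. A toy Python model of that re-run loop (hypothetical, not TensorFlow's API):

def refine(shape_fn, constants):
  # Re-run shape_fn whenever a requested input value can be statically
  # computed (modeled here as a lookup in `constants`).
  known, requested = {}, set()
  while True:
    result = shape_fn(known, requested)
    newly_known = {i: constants[i] for i in requested
                   if i in constants and i not in known}
    if not newly_known:
      return result
    known.update(newly_known)

def ones_shape_fn(known, requested):
  # Shape function for y = ones(shape=input0): it needs input 0's *value*.
  if 0 in known:
    return list(known[0])  # fully static output shape
  requested.add(0)         # cf. request_input_tensor(0)
  return None              # output shape unknown for now

print(refine(ones_shape_fn, constants={0: (3, 5)}))  # [3, 5]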


@@ -4297,6 +4297,92 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegex(TypeError, 'missing required arguments: y'):
foo.add(2) # pylint: disable=no-value-for-parameter
def testShapeInferencePropagateConstNestedStack(self):
@def_function.function(input_signature=[
tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
])
def f(x, s):
old_shape = array_ops.shape(x)
new_shape = array_ops.stack([old_shape[0], s], axis=0)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
])
def g(x):
y = f(x, s=5)
assert y.shape.as_list() == [3, 5], y.shape.as_list()
return y
self.assertAllEqual(
g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
def testShapeInferencePropagateConstNestedUnstackStack(self):
@def_function.function(input_signature=[
tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
])
def f(x, s):
s0, _ = array_ops.unstack(array_ops.shape(x), axis=0)
new_shape = array_ops.stack([s0, s], axis=0)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
])
def g(x):
y = f(x, s=5)
assert y.shape.as_list() == [3, 5], y.shape.as_list()
return y
self.assertAllEqual(
g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
def testShapeInferencePropagateConstNestedConcat(self):
@def_function.function(input_signature=[
tensor_spec.TensorSpec((), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
])
def f(d1, d2, d3):
new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@def_function.function()
def g():
y = f(1, 2, 3)
assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
return y
self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
def testShapeInferencePropagateConstDoubleNested(self):
@def_function.function(input_signature=[
tensor_spec.TensorSpec((), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
])
def f(d1, d2, d3):
new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@def_function.function()
def g():
y = def_function.function(f)(1, 2, 3)
assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
return y
self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
class MultiDeviceTest(test.TestCase, parameterized.TestCase):


@@ -2007,7 +2007,8 @@ class CentralCropTest(test_util.TensorFlowTestCase):
self.assertTrue(y.op.name.startswith("central_crop"))
class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
class PadToBoundingBoxTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def _PadToBoundingBox(self, x, offset_height, offset_width, target_height,
target_width, use_tensor_inputs):
@@ -2172,7 +2173,10 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
"inner 3 dims of \\'image.shape\\' must be > 0",
use_tensor_inputs_options=[True])
def testBadParams(self):
def testBadParamsScalarInputs(self):
# In this test, inputs do not get converted to tensors before calling the
# tf.function. The error message here is raised in Python
# since the Python function has direct access to the scalars.
x_shape = [3, 3, 1]
x = np.zeros(x_shape)
@@ -2187,9 +2191,49 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
"height must be <= target - offset"),
(0, 2, 4, 4,
"width must be <= target - offset"))
for config_item in test_config:
self._assertRaises(x, x_shape, *config_item)
self._assertRaises(
x, x_shape, *config_item, use_tensor_inputs_options=[False])
def testBadParamsTensorInputsEager(self):
# In this test, inputs get converted to EagerTensors before calling the
# tf.function. The error message here is raised in Python
# since the Python function has direct access to the tensors' values.
with context.eager_mode():
x_shape = [3, 3, 1]
x = np.zeros(x_shape)
# Each line is a test configuration:
# offset_height, offset_width, target_height, target_width, err_msg
test_config = (
(-1, 0, 4, 4,
"offset_height must be >= 0"),
(0, -1, 4, 4,
"offset_width must be >= 0"),
(2, 0, 4, 4,
"height must be <= target - offset"),
(0, 2, 4, 4,
"width must be <= target - offset"))
for config_item in test_config:
self._assertRaises(
x, x_shape, *config_item, use_tensor_inputs_options=[True])
@parameterized.named_parameters([("OffsetHeight", (-1, 0, 4, 4)),
("OffsetWidth", (0, -1, 4, 4)),
("Height", (2, 0, 4, 4)),
("Width", (0, 2, 4, 4))])
def testBadParamsTensorInputsGraph(self, config):
# In this test, inputs get converted to tensors before calling the
# tf.function. The error message here is raised during shape inference.
with context.graph_mode():
x_shape = [3, 3, 1]
x = np.zeros(x_shape)
self._assertRaises(
x,
x_shape,
*config,
"Paddings must be non-negative",
use_tensor_inputs_options=[True])
def testNameScope(self):
# Testing name scope requires a graph.