Support constant folding across Arg nodes during shape inference on nested function calls.
PiperOrigin-RevId: 337992497 Change-Id: If480c605fa34def3586a62105c5517c02c2f73e6
This commit is contained in:
parent
969055f71a
commit
f41a9a9335
@ -162,7 +162,8 @@
|
|||||||
stateful ops.
|
stateful ops.
|
||||||
* Added `tf.config.experimental.get_memory_usage` to return total memory
|
* Added `tf.config.experimental.get_memory_usage` to return total memory
|
||||||
usage of the device.
|
usage of the device.
|
||||||
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
|
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
|
||||||
|
* Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions.
|
||||||
* `tf.data`:
|
* `tf.data`:
|
||||||
* tf.data service:
|
* tf.data service:
|
||||||
* Added new `tf.data.experimental.service.register_dataset` and
|
* Added new `tf.data.experimental.service.register_dataset` and
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
@ -123,6 +124,17 @@ bool HasCpuKernel(const Node& node) {
|
|||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status GetArgNodeIndex(const Node* node, int num_function_inputs, int* index) {
|
||||||
|
DCHECK(node->IsArg());
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", index));
|
||||||
|
if (*index < 0 || num_function_inputs <= *index) {
|
||||||
|
return errors::Internal(
|
||||||
|
"Function instantiation included invalid input index: ", index,
|
||||||
|
" not in [0, ", num_function_inputs, ").");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Extracts the subgraph ending at 'target_node' that is statically computable
|
// Extracts the subgraph ending at 'target_node' that is statically computable
|
||||||
// and inserts into 'out_graph'. If statically computable, 'is_constant_graph'
|
// and inserts into 'out_graph'. If statically computable, 'is_constant_graph'
|
||||||
// will be set to true.
|
// will be set to true.
|
||||||
@ -130,7 +142,8 @@ Status ExtractConstantSubgraph(
|
|||||||
const Node& target_node, const ShapeRefiner& refiner,
|
const Node& target_node, const ShapeRefiner& refiner,
|
||||||
const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
|
const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
|
||||||
bool* is_constant_graph,
|
bool* is_constant_graph,
|
||||||
std::vector<std::pair<string, Tensor>>* const_inputs) {
|
std::vector<std::pair<string, Tensor>>* const_inputs,
|
||||||
|
InferenceContext* outer_context) {
|
||||||
*is_constant_graph = false;
|
*is_constant_graph = false;
|
||||||
std::unordered_set<string> const_inputs_added;
|
std::unordered_set<string> const_inputs_added;
|
||||||
|
|
||||||
@ -187,8 +200,9 @@ Status ExtractConstantSubgraph(
|
|||||||
edges_to_visit.pop_front();
|
edges_to_visit.pop_front();
|
||||||
Node* current_node = current_edge->src();
|
Node* current_node = current_edge->src();
|
||||||
|
|
||||||
// If the node is stateful, assume the graph is not constant.
|
// If the node is stateful, assume the graph is not constant unless it is
|
||||||
if (current_node->op_def().is_stateful()) {
|
// an Arg node which is handled later on.
|
||||||
|
if (!current_node->IsArg() && current_node->op_def().is_stateful()) {
|
||||||
*is_constant_graph = false;
|
*is_constant_graph = false;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -223,9 +237,32 @@ Status ExtractConstantSubgraph(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If there is nothing more to recurse down, see if
|
// If there is nothing more to recurse down, see if
|
||||||
// the generator node is a constant.
|
// the generator node is a constant or an Arg node whose value is available
|
||||||
|
// in the `outer_context`.
|
||||||
if (current_node->num_inputs() == 0) {
|
if (current_node->num_inputs() == 0) {
|
||||||
if (!current_node->IsConstant()) {
|
if (outer_context && current_node->IsArg()) {
|
||||||
|
const string& tensor_name =
|
||||||
|
strings::StrCat(current_node->name(), ":", 0);
|
||||||
|
// If we do not already have a constant Tensor for this Arg try to
|
||||||
|
// fetch it from the outer context.
|
||||||
|
if (const_inputs_added.count(tensor_name) == 0) {
|
||||||
|
int index;
|
||||||
|
TF_RETURN_IF_ERROR(GetArgNodeIndex(
|
||||||
|
current_node, outer_context->num_inputs(), &index));
|
||||||
|
const Tensor* const_tensor = outer_context->input_tensor(index);
|
||||||
|
if (const_tensor) {
|
||||||
|
const_inputs->emplace_back(tensor_name, *const_tensor);
|
||||||
|
const_inputs_added.insert(tensor_name);
|
||||||
|
} else {
|
||||||
|
// Request a constant value for this Arg. If that is statically
|
||||||
|
// computable, shape refiner will re-run the shape inference for
|
||||||
|
// this function with this tensor's value.
|
||||||
|
outer_context->request_input_tensor(index);
|
||||||
|
*is_constant_graph = false;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (!current_node->IsConstant()) {
|
||||||
// Generator node is not a constant, so subgraph is not
|
// Generator node is not a constant, so subgraph is not
|
||||||
// constant.
|
// constant.
|
||||||
*is_constant_graph = false;
|
*is_constant_graph = false;
|
||||||
@ -314,7 +351,8 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
|
|||||||
Tensor* result, GraphRunner* graph_runner,
|
Tensor* result, GraphRunner* graph_runner,
|
||||||
std::unordered_map<string, Tensor>* cached_values,
|
std::unordered_map<string, Tensor>* cached_values,
|
||||||
int64 max_cached_value_size,
|
int64 max_cached_value_size,
|
||||||
bool disable_constant_propagation) {
|
bool disable_constant_propagation,
|
||||||
|
InferenceContext* outer_context) {
|
||||||
*evaluated = false;
|
*evaluated = false;
|
||||||
const Node* src = tensor.node;
|
const Node* src = tensor.node;
|
||||||
|
|
||||||
@ -326,6 +364,22 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the source node is an Arg return its value, if available in the outer
|
||||||
|
// context.
|
||||||
|
if (src->IsArg() && outer_context) {
|
||||||
|
int index;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
GetArgNodeIndex(src, outer_context->num_inputs(), &index));
|
||||||
|
const Tensor* const_tensor = outer_context->input_tensor(index);
|
||||||
|
if (const_tensor) {
|
||||||
|
*evaluated = true;
|
||||||
|
*result = *(outer_context->input_tensor(index));
|
||||||
|
} else {
|
||||||
|
outer_context->request_input_tensor(index);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
if (disable_constant_propagation) {
|
if (disable_constant_propagation) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -339,7 +393,7 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
|
|||||||
std::vector<std::pair<string, Tensor>> const_inputs;
|
std::vector<std::pair<string, Tensor>> const_inputs;
|
||||||
TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
|
TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
|
||||||
&subgraph, &is_constant_graph,
|
&subgraph, &is_constant_graph,
|
||||||
&const_inputs));
|
&const_inputs, outer_context));
|
||||||
if (!is_constant_graph) {
|
if (!is_constant_graph) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -53,13 +53,17 @@ class Tensor;
|
|||||||
// result size to cache.
|
// result size to cache.
|
||||||
// disable_constant_propagation - if true, only Const node values will be
|
// disable_constant_propagation - if true, only Const node values will be
|
||||||
// returned.
|
// returned.
|
||||||
|
// outer_context - optional. The InferenceContext for the call node if inside
|
||||||
|
// a nested function. This is useful for doing constant propagation across
|
||||||
|
// Arg nodes.
|
||||||
Status EvaluateConstantTensor(
|
Status EvaluateConstantTensor(
|
||||||
OutputTensor tensor, const ShapeRefiner& refiner,
|
OutputTensor tensor, const ShapeRefiner& refiner,
|
||||||
const OpRegistryInterface& ops, int32 graph_def_version, bool* evaluated,
|
const OpRegistryInterface& ops, int32 graph_def_version, bool* evaluated,
|
||||||
Tensor* result, GraphRunner* graph_runner = nullptr,
|
Tensor* result, GraphRunner* graph_runner = nullptr,
|
||||||
std::unordered_map<string, Tensor>* cached_values = nullptr,
|
std::unordered_map<string, Tensor>* cached_values = nullptr,
|
||||||
int64 max_cached_value_size = 1024,
|
int64 max_cached_value_size = 1024,
|
||||||
bool disable_constant_propagation = false);
|
bool disable_constant_propagation = false,
|
||||||
|
shape_inference::InferenceContext* outer_context = nullptr);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
@ -59,13 +60,15 @@ namespace {
|
|||||||
constexpr char kArgOp[] = "_Arg";
|
constexpr char kArgOp[] = "_Arg";
|
||||||
constexpr char kRetvalOp[] = "_Retval";
|
constexpr char kRetvalOp[] = "_Retval";
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Runs shape inference for the given node using the given ShapeRefiner.
|
// Runs shape inference for the given node using the given ShapeRefiner.
|
||||||
// The node must be a sub-node of a function node and the outer_context is
|
// The node must be a sub-node of a function node and the outer_context is
|
||||||
// the inference context of that function node in the outer graph.
|
// the inference context of that function node in the outer graph.
|
||||||
Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
|
Status ShapeRefiner::InferShapesForFunctionSubNode(
|
||||||
InferenceContext* outer_context) {
|
const Node* node, InferenceContext* outer_context) {
|
||||||
TF_RETURN_IF_ERROR(refiner->AddNode(node));
|
TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
|
||||||
InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));
|
InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));
|
||||||
|
|
||||||
if (StringPiece(node->type_string()) == kArgOp) {
|
if (StringPiece(node->type_string()) == kArgOp) {
|
||||||
// Handle special node: function input.
|
// Handle special node: function input.
|
||||||
@ -126,8 +129,6 @@ Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// TODO(cwhipkey): When an inference context inside function has
|
// TODO(cwhipkey): When an inference context inside function has
|
||||||
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
|
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
|
||||||
// set when input(i) is an _Arg op, then this request should propagate to
|
// set when input(i) is an _Arg op, then this request should propagate to
|
||||||
@ -167,8 +168,8 @@ Status ShapeRefiner::InferShapesForFunction(
|
|||||||
auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
|
auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
|
||||||
&inference_status](const Node* node) {
|
&inference_status](const Node* node) {
|
||||||
if (!inference_status.ok()) return;
|
if (!inference_status.ok()) return;
|
||||||
inference_status = InferShapesForFunctionSubNode(
|
inference_status =
|
||||||
node, this, outer_context->get_context());
|
InferShapesForFunctionSubNode(node, outer_context->get_context());
|
||||||
function_nodes.insert(node);
|
function_nodes.insert(node);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -187,6 +188,11 @@ Status ShapeRefiner::InferShapesForFunction(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeRefiner::AddNode(const Node* node) {
|
Status ShapeRefiner::AddNode(const Node* node) {
|
||||||
|
return AddNodeInternal(node, /*outer_context=*/nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ShapeRefiner::AddNodeInternal(
|
||||||
|
const Node* node, shape_inference::InferenceContext* outer_context) {
|
||||||
// Create the inference context for this node with the existing input shapes.
|
// Create the inference context for this node with the existing input shapes.
|
||||||
std::unique_ptr<InferenceContext> ic(new InferenceContext(
|
std::unique_ptr<InferenceContext> ic(new InferenceContext(
|
||||||
graph_def_version_, node->def(), node->op_def(),
|
graph_def_version_, node->def(), node->op_def(),
|
||||||
@ -240,7 +246,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
|||||||
new ExtendedInferenceContext(std::move(ic), node));
|
new ExtendedInferenceContext(std::move(ic), node));
|
||||||
|
|
||||||
// Run the shape inference function, and return if there was an error.
|
// Run the shape inference function, and return if there was an error.
|
||||||
TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));
|
TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));
|
||||||
|
|
||||||
// Store the resulting context object in the map.
|
// Store the resulting context object in the map.
|
||||||
node_to_context_[node].swap(ec);
|
node_to_context_[node].swap(ec);
|
||||||
@ -385,25 +391,25 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
|
|||||||
return RunShapeFn(node, op_reg_data, node_ext_context);
|
return RunShapeFn(node, op_reg_data, node_ext_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
|
Status ShapeRefiner::EvaluateConstantTensorForEdge(
|
||||||
int dst_idx, bool* evaluated,
|
const Node* node, int dst_idx, bool* evaluated, Tensor* result,
|
||||||
Tensor* result) {
|
InferenceContext* outer_context) {
|
||||||
*evaluated = false;
|
*evaluated = false;
|
||||||
const Edge* input_edge;
|
const Edge* input_edge;
|
||||||
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
|
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
|
||||||
OutputTensor tensor(input_edge->src(), input_edge->src_output());
|
OutputTensor tensor(input_edge->src(), input_edge->src_output());
|
||||||
return EvaluateConstantTensor(tensor, *this, *ops_registry_,
|
return EvaluateConstantTensor(
|
||||||
graph_def_version_, evaluated, result,
|
tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
|
||||||
&graph_runner_, &const_tensor_map_,
|
&graph_runner_, &const_tensor_map_, kMaxTensorSize,
|
||||||
kMaxTensorSize, disable_constant_propagation_);
|
disable_constant_propagation_, outer_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
|
Status ShapeRefiner::EvaluateConstantIntScalarEdge(
|
||||||
int dst_idx, bool* evaluated,
|
const Node* node, int dst_idx, bool* evaluated, int64* result,
|
||||||
int64* result) {
|
shape_inference::InferenceContext* outer_context) {
|
||||||
Tensor scalar;
|
Tensor scalar;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
|
||||||
EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar));
|
&scalar, outer_context));
|
||||||
if (*evaluated) {
|
if (*evaluated) {
|
||||||
if (scalar.NumElements() != 1) {
|
if (scalar.NumElements() != 1) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
@ -424,9 +430,9 @@ Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
Status ShapeRefiner::ConstantPartialShape(
|
||||||
const Node* node, int dst_idx,
|
InferenceContext* target_context, const Node* node, int dst_idx,
|
||||||
ShapeHandle* result) {
|
ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
|
||||||
const Edge* input_edge;
|
const Edge* input_edge;
|
||||||
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
|
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
|
||||||
|
|
||||||
@ -437,8 +443,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
|||||||
if (src_context->Value(src_context->Rank(src_shape)) == 0) {
|
if (src_context->Value(src_context->Rank(src_shape)) == 0) {
|
||||||
Tensor t;
|
Tensor t;
|
||||||
bool evaluated = false;
|
bool evaluated = false;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
|
||||||
EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
|
&t, outer_context));
|
||||||
if (!evaluated) {
|
if (!evaluated) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Received a shape scalar with unknown static value. A static value "
|
"Received a shape scalar with unknown static value. A static value "
|
||||||
@ -471,7 +477,9 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
|||||||
// a float.
|
// a float.
|
||||||
Tensor t;
|
Tensor t;
|
||||||
bool evaluated = false;
|
bool evaluated = false;
|
||||||
if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t).ok()) {
|
if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
|
||||||
|
outer_context)
|
||||||
|
.ok()) {
|
||||||
if (evaluated &&
|
if (evaluated &&
|
||||||
target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
|
target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -481,7 +489,7 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
|||||||
// Then try to infer partial shape from the input to the cast tensor.
|
// Then try to infer partial shape from the input to the cast tensor.
|
||||||
ShapeHandle pre_cast_shape;
|
ShapeHandle pre_cast_shape;
|
||||||
if (!ConstantPartialShape(target_context, input_edge->src(), 0,
|
if (!ConstantPartialShape(target_context, input_edge->src(), 0,
|
||||||
&pre_cast_shape)
|
&pre_cast_shape, outer_context)
|
||||||
.ok()) {
|
.ok()) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
target_context->MakeShapeFromTensor(nullptr, src_shape, result));
|
target_context->MakeShapeFromTensor(nullptr, src_shape, result));
|
||||||
@ -510,8 +518,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
|||||||
for (int i = 0; i < src_context->num_inputs(); ++i) {
|
for (int i = 0; i < src_context->num_inputs(); ++i) {
|
||||||
int64 size;
|
int64 size;
|
||||||
bool evaluated;
|
bool evaluated;
|
||||||
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i,
|
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
|
||||||
&evaluated, &size));
|
input_edge->src(), i, &evaluated, &size, outer_context));
|
||||||
if (evaluated) {
|
if (evaluated) {
|
||||||
dims.push_back(size < 0 ? target_context->UnknownDim()
|
dims.push_back(size < 0 ? target_context->UnknownDim()
|
||||||
: target_context->MakeDim(size));
|
: target_context->MakeDim(size));
|
||||||
@ -531,7 +539,7 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
|||||||
if (i == concat_dim) continue;
|
if (i == concat_dim) continue;
|
||||||
ShapeHandle sub_result;
|
ShapeHandle sub_result;
|
||||||
TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
|
TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
|
||||||
i, &sub_result));
|
i, &sub_result, outer_context));
|
||||||
if (!target_context->RankKnown(sub_result)) {
|
if (!target_context->RankKnown(sub_result)) {
|
||||||
// Failed to evaluate. Treat the output as completely unknown.
|
// Failed to evaluate. Treat the output as completely unknown.
|
||||||
// TODO(cwhipkey): we could rely on all inputs being the same rank, so
|
// TODO(cwhipkey): we could rely on all inputs being the same rank, so
|
||||||
@ -543,8 +551,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
|||||||
target_context->Concatenate(*result, sub_result, result));
|
target_context->Concatenate(*result, sub_result, result));
|
||||||
}
|
}
|
||||||
} else if (src_op == "StridedSlice") {
|
} else if (src_op == "StridedSlice") {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
|
||||||
PartialStridedSliceShape(input_edge->src(), src_context, result));
|
result, outer_context));
|
||||||
} else if (src_op == "VariableShape") {
|
} else if (src_op == "VariableShape") {
|
||||||
auto* handle_data = src_context->input_handle_shapes_and_types(0);
|
auto* handle_data = src_context->input_handle_shapes_and_types(0);
|
||||||
if (handle_data != nullptr && !handle_data->empty()) {
|
if (handle_data != nullptr && !handle_data->empty()) {
|
||||||
@ -555,17 +563,17 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
|
|||||||
} else {
|
} else {
|
||||||
Tensor t;
|
Tensor t;
|
||||||
bool evaluated = false;
|
bool evaluated = false;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
|
||||||
EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
|
&t, outer_context));
|
||||||
TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
|
TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
|
||||||
evaluated ? &t : nullptr, src_shape, result));
|
evaluated ? &t : nullptr, src_shape, result));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
Status ShapeRefiner::PartialStridedSliceShape(
|
||||||
InferenceContext* ctx,
|
Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
|
||||||
ShapeHandle* result) {
|
shape_inference::InferenceContext* outer_context) {
|
||||||
// Only attempt to evaluate if begin/end/strides all are scalars.
|
// Only attempt to evaluate if begin/end/strides all are scalars.
|
||||||
for (int i = 1; i <= 3; ++i) {
|
for (int i = 1; i <= 3; ++i) {
|
||||||
ShapeHandle input_shape = ctx->input(i);
|
ShapeHandle input_shape = ctx->input(i);
|
||||||
@ -600,8 +608,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
|||||||
if (begin_mask == 1) {
|
if (begin_mask == 1) {
|
||||||
begin = 0;
|
begin = 0;
|
||||||
} else {
|
} else {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
|
||||||
EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin));
|
&begin, outer_context));
|
||||||
if (!evaluated) {
|
if (!evaluated) {
|
||||||
*result = ctx->UnknownShape();
|
*result = ctx->UnknownShape();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -612,8 +620,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
|||||||
if (end_mask == 1) {
|
if (end_mask == 1) {
|
||||||
end = std::numeric_limits<int64>::max();
|
end = std::numeric_limits<int64>::max();
|
||||||
} else {
|
} else {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
|
||||||
EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end));
|
&end, outer_context));
|
||||||
if (!evaluated) {
|
if (!evaluated) {
|
||||||
*result = ctx->UnknownShape();
|
*result = ctx->UnknownShape();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -621,8 +629,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
|||||||
}
|
}
|
||||||
|
|
||||||
int64 stride;
|
int64 stride;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
|
||||||
EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride));
|
&stride, outer_context));
|
||||||
if (!evaluated) {
|
if (!evaluated) {
|
||||||
*result = ctx->UnknownShape();
|
*result = ctx->UnknownShape();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -630,14 +638,16 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
|
|||||||
|
|
||||||
// Apply stride to input interpreted as a partial shape.
|
// Apply stride to input interpreted as a partial shape.
|
||||||
ShapeHandle input;
|
ShapeHandle input;
|
||||||
TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input));
|
TF_RETURN_IF_ERROR(
|
||||||
|
ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
|
||||||
TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
|
TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeRefiner::RunShapeFn(const Node* node,
|
Status ShapeRefiner::RunShapeFn(const Node* node,
|
||||||
const OpRegistrationData* op_reg_data,
|
const OpRegistrationData* op_reg_data,
|
||||||
ExtendedInferenceContext* ec) {
|
ExtendedInferenceContext* ec,
|
||||||
|
InferenceContext* outer_context) {
|
||||||
// This will be filled in with real data in a second pass.
|
// This will be filled in with real data in a second pass.
|
||||||
std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
|
std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
|
||||||
std::vector<Tensor> real_tensors(node->num_inputs());
|
std::vector<Tensor> real_tensors(node->num_inputs());
|
||||||
@ -719,8 +729,8 @@ Status ShapeRefiner::RunShapeFn(const Node* node,
|
|||||||
|
|
||||||
Tensor result;
|
Tensor result;
|
||||||
bool evaluated = false;
|
bool evaluated = false;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
|
||||||
EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
|
node, i, &evaluated, &result, outer_context));
|
||||||
if (evaluated) {
|
if (evaluated) {
|
||||||
real_tensors[i] = result;
|
real_tensors[i] = result;
|
||||||
input_tensors[i] = &real_tensors[i];
|
input_tensors[i] = &real_tensors[i];
|
||||||
@ -736,7 +746,7 @@ Status ShapeRefiner::RunShapeFn(const Node* node,
|
|||||||
input_tensors_as_shapes.resize(i + 1);
|
input_tensors_as_shapes.resize(i + 1);
|
||||||
}
|
}
|
||||||
ShapeHandle s;
|
ShapeHandle s;
|
||||||
TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
|
TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s, outer_context));
|
||||||
input_tensors_as_shapes[i] = s;
|
input_tensors_as_shapes[i] = s;
|
||||||
rerun_shape_fn = true;
|
rerun_shape_fn = true;
|
||||||
}
|
}
|
||||||
|
@ -184,17 +184,56 @@ class ShapeRefiner {
|
|||||||
AttrSlice attributes,
|
AttrSlice attributes,
|
||||||
ExtendedInferenceContext* outer_context);
|
ExtendedInferenceContext* outer_context);
|
||||||
|
|
||||||
|
// Performs shape inference for a node inside a function.
|
||||||
|
//
|
||||||
|
// 'outer_context' is the 'InferenceContext' for the function's call op.
|
||||||
|
Status InferShapesForFunctionSubNode(
|
||||||
|
const Node* node, shape_inference::InferenceContext* outer_context);
|
||||||
|
|
||||||
|
// Performs validation of 'node' and runs 'node's shape function,
|
||||||
|
// storing its shape outputs.
|
||||||
|
//
|
||||||
|
// All inputs of 'node' must be added to ShapeRefiner prior to
|
||||||
|
// adding 'node'.
|
||||||
|
//
|
||||||
|
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||||
|
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||||
|
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||||
|
// by requesting the constant of value of the incoming tensor from the
|
||||||
|
// 'outer_context'.
|
||||||
|
//
|
||||||
|
// Returns an error if:
|
||||||
|
// - the shape function for 'node' was not registered.
|
||||||
|
// - 'node' was added before its inputs.
|
||||||
|
// - The shape inference function returns an error.
|
||||||
|
Status AddNodeInternal(const Node* node,
|
||||||
|
shape_inference::InferenceContext* outer_context);
|
||||||
|
|
||||||
// Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
|
// Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
|
||||||
// value can be evaluated, 'evaluated' is set to true and the value returned
|
// value can be evaluated, 'evaluated' is set to true and the value returned
|
||||||
// in 'result'. Otherwise 'evaluated' is set to false.
|
// in 'result'. Otherwise 'evaluated' is set to false.
|
||||||
Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
|
//
|
||||||
bool* evaluated, Tensor* result);
|
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||||
|
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||||
|
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||||
|
// by requesting the constant of value of the incoming tensor from the
|
||||||
|
// 'outer_context'.
|
||||||
|
Status EvaluateConstantTensorForEdge(
|
||||||
|
const Node* node, int dst_idx, bool* evaluated, Tensor* result,
|
||||||
|
shape_inference::InferenceContext* outer_context);
|
||||||
|
|
||||||
// Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
|
// Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
|
||||||
// tensors. The caller is responsible for checking that the specified edge is
|
// tensors. The caller is responsible for checking that the specified edge is
|
||||||
// scalar and int32 or int64.
|
// scalar and int32 or int64.
|
||||||
Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx,
|
//
|
||||||
bool* evaluated, int64* result);
|
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||||
|
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||||
|
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||||
|
// by requesting the constant of value of the incoming tensor from the
|
||||||
|
// 'outer_context'.
|
||||||
|
Status EvaluateConstantIntScalarEdge(
|
||||||
|
const Node* node, int dst_idx, bool* evaluated, int64* result,
|
||||||
|
shape_inference::InferenceContext* outer_context);
|
||||||
|
|
||||||
// This function tries to materialize as much information about the 'node''s
|
// This function tries to materialize as much information about the 'node''s
|
||||||
// dst_idx input as a statically computable shape, and the result may be
|
// dst_idx input as a statically computable shape, and the result may be
|
||||||
@ -217,17 +256,39 @@ class ShapeRefiner {
|
|||||||
//
|
//
|
||||||
// <target_context> is used when creating new DimensionHandle and ShapeHandle
|
// <target_context> is used when creating new DimensionHandle and ShapeHandle
|
||||||
// objects.
|
// objects.
|
||||||
|
//
|
||||||
|
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||||
|
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||||
|
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||||
|
// by requesting the constant of value of the incoming tensor from the
|
||||||
|
// 'outer_context'.
|
||||||
Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
|
Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
|
||||||
const Node* node, int dst_idx,
|
const Node* node, int dst_idx,
|
||||||
shape_inference::ShapeHandle* result);
|
shape_inference::ShapeHandle* result,
|
||||||
|
shape_inference::InferenceContext* outer_context);
|
||||||
|
|
||||||
// Implementation of ConstantPartialShape for StridedSlice nodes.
|
// Implementation of ConstantPartialShape for StridedSlice nodes.
|
||||||
Status PartialStridedSliceShape(Node* slice_node,
|
//
|
||||||
shape_inference::InferenceContext* ctx,
|
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||||
shape_inference::ShapeHandle* result);
|
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||||
|
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||||
|
// by requesting the constant of value of the incoming tensor from the
|
||||||
|
// 'outer_context'.
|
||||||
|
Status PartialStridedSliceShape(
|
||||||
|
Node* slice_node, shape_inference::InferenceContext* ctx,
|
||||||
|
shape_inference::ShapeHandle* result,
|
||||||
|
shape_inference::InferenceContext* outer_context);
|
||||||
|
|
||||||
|
// Runs the shape function registered for the node's op type.
|
||||||
|
//
|
||||||
|
// Optionally, if 'node' is in a nested function, the 'InferenceContext' for
|
||||||
|
// the call op of the function can be passed as 'outer_context' (pass nullptr
|
||||||
|
// otherwise). This gets used to perform constant propagation across Arg nodes
|
||||||
|
// by requesting the constant of value of the incoming tensor from the
|
||||||
|
// 'outer_context'.
|
||||||
Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
|
Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
|
||||||
ExtendedInferenceContext* ec);
|
ExtendedInferenceContext* ec,
|
||||||
|
shape_inference::InferenceContext* outer_context = nullptr);
|
||||||
|
|
||||||
int32 graph_def_version_;
|
int32 graph_def_version_;
|
||||||
const OpRegistryInterface* const ops_registry_;
|
const OpRegistryInterface* const ops_registry_;
|
||||||
|
@ -719,7 +719,7 @@ Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
|
|||||||
ShapeHandle input_shape;
|
ShapeHandle input_shape;
|
||||||
TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
|
TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
|
||||||
|
|
||||||
requested_input_tensor_as_partial_shape_[input_idx] = true;
|
request_input_tensor_as_partial_shape(input_idx);
|
||||||
const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
|
const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
|
||||||
if (input_idx < input_tensors_as_shapes_size &&
|
if (input_idx < input_tensors_as_shapes_size &&
|
||||||
input_tensors_as_shapes_[input_idx].IsSet() &&
|
input_tensors_as_shapes_[input_idx].IsSet() &&
|
||||||
@ -738,7 +738,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
|
|||||||
ShapeHandle input_shape;
|
ShapeHandle input_shape;
|
||||||
TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
|
TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
|
||||||
|
|
||||||
requested_input_tensor_as_partial_shape_[input_idx] = true;
|
request_input_tensor_as_partial_shape(input_idx);
|
||||||
const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
|
const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
|
||||||
if (input_idx < input_tensors_as_shapes_size &&
|
if (input_idx < input_tensors_as_shapes_size &&
|
||||||
input_tensors_as_shapes_[input_idx].IsSet() &&
|
input_tensors_as_shapes_[input_idx].IsSet() &&
|
||||||
|
@ -268,15 +268,31 @@ class InferenceContext {
|
|||||||
// not available at the time of shape inference.
|
// not available at the time of shape inference.
|
||||||
const Tensor* input_tensor(int idx) {
|
const Tensor* input_tensor(int idx) {
|
||||||
// Mark that this idx was requested.
|
// Mark that this idx was requested.
|
||||||
requested_input_tensor_[idx] = true;
|
request_input_tensor(idx);
|
||||||
return input_tensors_[idx];
|
return input_tensors_[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Notifies the shape refiner that the value of the tensor at index <idx>
|
||||||
|
// is needed. The shape refiner tries to statically compute this tensor,
|
||||||
|
// and if successful re-runs the shape function with this tensor available
|
||||||
|
// in the call to 'input_tensor(idx)'.
|
||||||
|
void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; }
|
||||||
|
|
||||||
// Returns true iff input_tensor(idx) was called by the shape function.
|
// Returns true iff input_tensor(idx) was called by the shape function.
|
||||||
bool requested_input_tensor(int idx) const {
|
bool requested_input_tensor(int idx) const {
|
||||||
return requested_input_tensor_[idx];
|
return requested_input_tensor_[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Notifies the shape refiner that the value of the tensor at index <idx>
|
||||||
|
// as a partial shape is needed. The shape refiner tries to statically compute
|
||||||
|
// this, and if successful re-runs the shape function with the
|
||||||
|
// computed PartialTensorShape available in the call to
|
||||||
|
// 'MakeShapeFromShapeTensor(idx, handle)' or
|
||||||
|
// 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'.
|
||||||
|
void request_input_tensor_as_partial_shape(int idx) {
|
||||||
|
requested_input_tensor_as_partial_shape_[idx] = true;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns true if MakeShapeFromInputTensor was called but the constant
|
// Returns true if MakeShapeFromInputTensor was called but the constant
|
||||||
// input_tensor was not present.
|
// input_tensor was not present.
|
||||||
bool requested_input_tensor_as_partial_shape(int idx) const {
|
bool requested_input_tensor_as_partial_shape(int idx) const {
|
||||||
|
@ -4297,6 +4297,92 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.assertRaisesRegex(TypeError, 'missing required arguments: y'):
|
with self.assertRaisesRegex(TypeError, 'missing required arguments: y'):
|
||||||
foo.add(2) # pylint: disable=no-value-for-parameter
|
foo.add(2) # pylint: disable=no-value-for-parameter
|
||||||
|
|
||||||
|
def testShapeInferencePropagateConstNestedStack(self):
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
|
||||||
|
tensor_spec.TensorSpec((), dtype=dtypes.int32),
|
||||||
|
])
|
||||||
|
def f(x, s):
|
||||||
|
old_shape = array_ops.shape(x)
|
||||||
|
new_shape = array_ops.stack([old_shape[0], s], axis=0)
|
||||||
|
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
|
||||||
|
])
|
||||||
|
def g(x):
|
||||||
|
y = f(x, s=5)
|
||||||
|
assert y.shape.as_list() == [3, 5], y.shape.as_list()
|
||||||
|
return y
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
|
||||||
|
|
||||||
|
def testShapeInferencePropagateConstNestedUnstackStack(self):
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
|
||||||
|
tensor_spec.TensorSpec((), dtype=dtypes.int32),
|
||||||
|
])
|
||||||
|
def f(x, s):
|
||||||
|
s0, _ = array_ops.unstack(array_ops.shape(x), axis=0)
|
||||||
|
new_shape = array_ops.stack([s0, s], axis=0)
|
||||||
|
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
|
||||||
|
])
|
||||||
|
def g(x):
|
||||||
|
y = f(x, s=5)
|
||||||
|
assert y.shape.as_list() == [3, 5], y.shape.as_list()
|
||||||
|
return y
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
|
||||||
|
|
||||||
|
def testShapeInferencePropagateConstNestedConcat(self):
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec((), dtype=dtypes.int32),
|
||||||
|
tensor_spec.TensorSpec((), dtype=dtypes.int32),
|
||||||
|
tensor_spec.TensorSpec((), dtype=dtypes.int32),
|
||||||
|
])
|
||||||
|
def f(d1, d2, d3):
|
||||||
|
new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
|
||||||
|
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@def_function.function()
|
||||||
|
def g():
|
||||||
|
y = f(1, 2, 3)
|
||||||
|
assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
|
||||||
|
return y
|
||||||
|
|
||||||
|
self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
|
||||||
|
|
||||||
|
def testShapeInferencePropagateConstDoubleNested(self):
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec((), dtype=dtypes.int32),
|
||||||
|
tensor_spec.TensorSpec((), dtype=dtypes.int32),
|
||||||
|
tensor_spec.TensorSpec((), dtype=dtypes.int32),
|
||||||
|
])
|
||||||
|
def f(d1, d2, d3):
|
||||||
|
new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
|
||||||
|
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@def_function.function()
|
||||||
|
def g():
|
||||||
|
y = def_function.function(f)(1, 2, 3)
|
||||||
|
assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
|
||||||
|
return y
|
||||||
|
|
||||||
|
self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
|
||||||
|
|
||||||
|
|
||||||
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@ -2007,7 +2007,8 @@ class CentralCropTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(y.op.name.startswith("central_crop"))
|
self.assertTrue(y.op.name.startswith("central_crop"))
|
||||||
|
|
||||||
|
|
||||||
class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
|
class PadToBoundingBoxTest(test_util.TensorFlowTestCase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
|
||||||
def _PadToBoundingBox(self, x, offset_height, offset_width, target_height,
|
def _PadToBoundingBox(self, x, offset_height, offset_width, target_height,
|
||||||
target_width, use_tensor_inputs):
|
target_width, use_tensor_inputs):
|
||||||
@ -2172,7 +2173,10 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
|
|||||||
"inner 3 dims of \\'image.shape\\' must be > 0",
|
"inner 3 dims of \\'image.shape\\' must be > 0",
|
||||||
use_tensor_inputs_options=[True])
|
use_tensor_inputs_options=[True])
|
||||||
|
|
||||||
def testBadParams(self):
|
def testBadParamsScalarInputs(self):
|
||||||
|
# In this test, inputs do not get converted to tensors before calling the
|
||||||
|
# tf.function. The error message here is raised in python
|
||||||
|
# since the python function has direct access to the scalars.
|
||||||
x_shape = [3, 3, 1]
|
x_shape = [3, 3, 1]
|
||||||
x = np.zeros(x_shape)
|
x = np.zeros(x_shape)
|
||||||
|
|
||||||
@ -2187,9 +2191,49 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
|
|||||||
"height must be <= target - offset"),
|
"height must be <= target - offset"),
|
||||||
(0, 2, 4, 4,
|
(0, 2, 4, 4,
|
||||||
"width must be <= target - offset"))
|
"width must be <= target - offset"))
|
||||||
|
|
||||||
for config_item in test_config:
|
for config_item in test_config:
|
||||||
self._assertRaises(x, x_shape, *config_item)
|
self._assertRaises(
|
||||||
|
x, x_shape, *config_item, use_tensor_inputs_options=[False])
|
||||||
|
|
||||||
|
def testBadParamsTensorInputsEager(self):
|
||||||
|
# In this test inputs get converted to EagerTensors before calling the
|
||||||
|
# tf.function. The error message here is raised in python
|
||||||
|
# since the python function has direct access to the tensor's values.
|
||||||
|
with context.eager_mode():
|
||||||
|
x_shape = [3, 3, 1]
|
||||||
|
x = np.zeros(x_shape)
|
||||||
|
|
||||||
|
# Each line is a test configuration:
|
||||||
|
# offset_height, offset_width, target_height, target_width, err_msg
|
||||||
|
test_config = (
|
||||||
|
(-1, 0, 4, 4,
|
||||||
|
"offset_height must be >= 0"),
|
||||||
|
(0, -1, 4, 4,
|
||||||
|
"offset_width must be >= 0"),
|
||||||
|
(2, 0, 4, 4,
|
||||||
|
"height must be <= target - offset"),
|
||||||
|
(0, 2, 4, 4,
|
||||||
|
"width must be <= target - offset"))
|
||||||
|
for config_item in test_config:
|
||||||
|
self._assertRaises(
|
||||||
|
x, x_shape, *config_item, use_tensor_inputs_options=[True])
|
||||||
|
|
||||||
|
@parameterized.named_parameters([("OffsetHeight", (-1, 0, 4, 4)),
|
||||||
|
("OffsetWidth", (0, -1, 4, 4)),
|
||||||
|
("Height", (2, 0, 4, 4)),
|
||||||
|
("Width", (0, 2, 4, 4))])
|
||||||
|
def testBadParamsTensorInputsGraph(self, config):
|
||||||
|
# In this test inputs get converted to tensors before calling the
|
||||||
|
# tf.function. The error message here is raised during shape inference.
|
||||||
|
with context.graph_mode():
|
||||||
|
x_shape = [3, 3, 1]
|
||||||
|
x = np.zeros(x_shape)
|
||||||
|
self._assertRaises(
|
||||||
|
x,
|
||||||
|
x_shape,
|
||||||
|
*config,
|
||||||
|
"Paddings must be non-negative",
|
||||||
|
use_tensor_inputs_options=[True])
|
||||||
|
|
||||||
def testNameScope(self):
|
def testNameScope(self):
|
||||||
# Testing name scope requires a graph.
|
# Testing name scope requires a graph.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user