Handle input_tensors in C shape inference API.
PiperOrigin-RevId: 261761496
commit 038d0cdabc
parent 9683994327
@@ -1053,6 +1053,10 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
   delete[] shape_list_array;
 }
 
+namespace tensorflow {
+Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
+}  // namespace tensorflow
+
 void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
                      TF_Tensor** input_tensors,
                      TF_ShapeAndTypeList* input_tensors_as_shapes,
@@ -1082,10 +1086,26 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
   tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
   if (!status->status.ok()) return;
 
+  // Initialize an input_tensor vector with `nullptr` values.
+  std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
+  // A vector to keep track of newly created `tf::Tensor` objects.
+  std::vector<Tensor> all_input_tensors;
+  // Update the vector with information from `input_tensors` if provided.
+  if (input_tensors != nullptr) {
+    for (int i = 0; i < num_inputs; ++i) {
+      if (input_tensors[i] == nullptr) continue;
+      all_input_tensors.emplace_back();
+      Tensor& input_tensor = all_input_tensors.back();
+      status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
+      if (!status->status.ok()) return;
+      input_tensors_vector[i] = &input_tensor;
+    }
+  }
+
   // Create an inference context with dummy values, which will be updated later.
   InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def,
-                     std::vector<ShapeHandle>(num_inputs),
-                     std::vector<const Tensor*>(num_inputs, nullptr), {},
+                     std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
+                     {},
                      std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
 
   // Set input_shapes.
@@ -1102,7 +1122,6 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
     c.SetInput(i, c.MakeShape(dims));
   }
 
-  // TODO(bgogul): Handle input_tensors.
   // TODO(bgogul): Handle input_tensors_as_shapes.
   // TODO(bgogul): Handle input_resource_shapes_and_types.
 
@@ -378,10 +378,15 @@ TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeList(
 TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeListArray(
     TF_ShapeAndTypeList** shape_list_array, int num_items);
 
-// Infer shapes for the given `node_def`. The arguments mimic the arguments of
-// the `shape_inference::InferenceContext` constructor. The types need not be
-// set in `input_shapes` as it is not used for shape inference. The number of
-// `input_tensors` should be the same as the number of items in `input_shapes`.
+// Infer shapes for the given `op`. The arguments mimic the arguments of the
+// `shape_inference::InferenceContext` constructor. Note the following:
+//   - The inputs of the `op` are not used for shape inference. So, it is
+//     OK to not have the inputs properly set in `op`. See `input_tensors`
+//     if you want shape inference to consider the input tensors of the
+//     op for shape inference.
+//   - The types need not be set in `input_shapes` as it is not used.
+//   - The number of `input_tensors` should be the same as the number of items
+//     in `input_shapes`.
 //
 // The results are returned in `output_shapes` and
 // `output_resource_shapes_and_types`. The caller is responsible for freeing the
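For illustration only (not part of this commit), the following is a minimal sketch of how a caller might pass `input_tensors` to `TFE_InferShapes` so that shape inference can resolve an otherwise-unknown output shape; it mirrors the Reshape case in the test added below. It assumes the TF_ShapeAndTypeList construction helpers (`TF_NewShapeAndTypeList`, `TF_ShapeAndTypeListSetUnknownShape`) and the standard tensor allocation functions from the TensorFlow C API headers, which are not shown in this diff.

// Illustrative sketch: infer Reshape's output shape by supplying the concrete
// shape tensor through `input_tensors` while both input shapes stay unknown.
// Helper names outside this diff are assumed from the TensorFlow C API headers.
#include <cassert>
#include <cstdint>
#include <cstring>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  assert(TF_GetCode(status) == TF_OK);

  // Only the attrs are set; the op's inputs are not needed for inference.
  TFE_Op* reshape = TFE_NewOp(ctx, "Reshape", status);
  assert(TF_GetCode(status) == TF_OK);
  TFE_OpSetAttrType(reshape, "T", TF_FLOAT);
  TFE_OpSetAttrType(reshape, "Tshape", TF_INT32);

  // Both input shapes are unknown.
  TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(/*num_shapes=*/2);
  TF_ShapeAndTypeListSetUnknownShape(input_shapes, 0);
  TF_ShapeAndTypeListSetUnknownShape(input_shapes, 1);

  // The second input (the target shape) is supplied as an int32 tensor {1, 6}.
  const int64_t dims[1] = {2};
  TF_Tensor* shape_tensor =
      TF_AllocateTensor(TF_INT32, dims, 1, 2 * sizeof(int32_t));
  const int32_t shape_values[2] = {1, 6};
  std::memcpy(TF_TensorData(shape_tensor), shape_values, sizeof(shape_values));
  TF_Tensor* input_tensors[2] = {nullptr, shape_tensor};

  TF_ShapeAndTypeList* output_shapes = nullptr;
  TFE_InferShapes(reshape, input_shapes, input_tensors,
                  /*input_tensors_as_shapes*/ nullptr,
                  /*input_resource_shapes_and_types*/ nullptr, &output_shapes,
                  /*output_resource_shapes_and_types*/ nullptr, status);
  assert(TF_GetCode(status) == TF_OK);
  // output_shapes->items[0] is expected to describe the shape {1, 6}.

  TF_DeleteShapeAndTypeList(output_shapes);
  TF_DeleteShapeAndTypeList(input_shapes);
  TF_DeleteTensor(shape_tensor);
  TFE_DeleteOp(reshape);
  TFE_DeleteContext(ctx);
  TFE_DeleteContextOptions(opts);
  TF_DeleteStatus(status);
  return 0;
}

Positions left as nullptr in `input_tensors` simply fall back to shape-only inference, matching the `if (input_tensors[i] == nullptr) continue;` handling added in the implementation above.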
@@ -439,7 +439,6 @@ class ShapeInferenceTest : public ::testing::Test {
       : status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) {
     tfe_context_ = TFE_NewContext(tfe_context_options_, status_);
     CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
-    CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
   }
 
   ~ShapeInferenceTest() override {
@@ -452,7 +451,7 @@ class ShapeInferenceTest : public ::testing::Test {
   void CheckOutputShapes(
       TFE_Op* op,
       const std::vector<absl::optional<std::vector<int64_t>>>& input_shapes_vec,
-      TF_Tensor** input_tensors,
+      const std::vector<TF_Tensor*>& input_tensors,
       const absl::optional<std::vector<int64_t>>& expected_shape) {
     // Create input_shapes.
     TF_ShapeAndTypeList* input_shapes =
@@ -467,7 +466,10 @@ class ShapeInferenceTest : public ::testing::Test {
       }
     }
     TF_ShapeAndTypeList* output_shapes;
-    TFE_InferShapes(op, input_shapes, input_tensors,
+    TFE_InferShapes(op, input_shapes,
+                    input_tensors.empty()
+                        ? nullptr
+                        : const_cast<TF_Tensor**>(input_tensors.data()),
                     /*input_tensors_as_shapes*/ nullptr,
                     /*input_resource_shapes_and_types*/ nullptr, &output_shapes,
                     /*output_resource_shapes_and_types*/ nullptr, status_);
@@ -515,31 +517,65 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) {
   // Infer shape when everything is known.
   CheckOutputShapes(matmul_op,
                     /*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})},
-                    /*input_tensors*/ nullptr,
+                    /*input_tensors*/ {},
                     /*expected_shape*/ make_shape({3, 4}));
 
   // Infer shape when second operand has unknown shape.
   CheckOutputShapes(matmul_op,
                     /*input_shapes*/ {make_shape({3, 2}), unknown_shape()},
-                    /*input_tensors*/ nullptr,
+                    /*input_tensors*/ {},
                     /*expected_shape*/ make_shape({3, kUnknownDim}));
 
   // Infer shape when some dimensions are unknown.
   CheckOutputShapes(
       matmul_op,
       /*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})},
-      /*input_tensors*/ nullptr,
+      /*input_tensors*/ {},
      /*expected_shape*/ make_shape({kUnknownDim, 4}));
 
   // Infer shape when everything is unknown.
   CheckOutputShapes(matmul_op,
                     /*input_shapes*/ {unknown_shape(), unknown_shape()},
-                    /*input_tensors*/ nullptr,
+                    /*input_tensors*/ {},
                     /*expected_shape*/ make_shape({kUnknownDim, kUnknownDim}));
 
   TFE_DeleteOp(matmul_op);
   // TODO(bgogul): Add some death tests where status is not OK.
 }
 
+TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
+  // Prepare some tensors for shape.
+  TF_Tensor* tensor_1X6 = Int32Tensor({1, 6});
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+  TF_Tensor* tensor_1X1X6 = Int32Tensor({1, 1, 6});
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+
+  TFE_Op* reshape_op = TFE_NewOp(tfe_context_, "Reshape", status_);
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+  TFE_OpSetAttrType(reshape_op, "T", TF_FLOAT);
+  TFE_OpSetAttrType(reshape_op, "Tshape", TF_INT32);
+  CheckOutputShapes(reshape_op,
+                    /* input_shapes*/ {unknown_shape(), unknown_shape()},
+                    /* input_tensors*/ {nullptr, tensor_1X6},
+                    /*expected_shape*/ make_shape({1, 6}));
+  TFE_DeleteOp(reshape_op);
+  reshape_op = nullptr;
+
+  TFE_Op* fill_op = TFE_NewOp(tfe_context_, "Fill", status_);
+  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
+  TFE_OpSetAttrType(fill_op, "T", TF_FLOAT);
+  TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
+
+  CheckOutputShapes(fill_op,
+                    /* input_shapes*/ {unknown_shape(), unknown_shape()},
+                    /* input_tensors*/ {tensor_1X1X6, nullptr},
+                    /*expected_shape*/ make_shape({1, 1, 6}));
+  TFE_DeleteOp(fill_op);
+  fill_op = nullptr;
+
+  TF_DeleteTensor(tensor_1X1X6);
+  TF_DeleteTensor(tensor_1X6);
+}
+
 }  // namespace
 }  // namespace tensorflow