Handle input_tensors in C shape inference API.

PiperOrigin-RevId: 261761496
This commit is contained in:
A. Unique TensorFlower 2019-08-05 13:51:51 -07:00 committed by TensorFlower Gardener
parent 9683994327
commit 038d0cdabc
3 changed files with 74 additions and 14 deletions

View File

@ -1053,6 +1053,10 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
delete[] shape_list_array;
}
namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
} // namespace tensorflow
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
TF_Tensor** input_tensors,
TF_ShapeAndTypeList* input_tensors_as_shapes,
@ -1082,10 +1086,26 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
if (!status->status.ok()) return;
// Initialize a input_tensor vector with `nullptr` values.
std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
// A vector to keep track of newly created `tf::Tensor` objects.
std::vector<Tensor> all_input_tensors;
// Update the vector with information from `input_tensors` if provided.
if (input_tensors != nullptr) {
for (int i = 0; i < num_inputs; ++i) {
if (input_tensors[i] == nullptr) continue;
all_input_tensors.emplace_back();
Tensor& input_tensor = all_input_tensors.back();
status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
if (!status->status.ok()) return;
input_tensors_vector[i] = &input_tensor;
}
}
// Create an inference context with dummy values, which will be updated later.
InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def,
std::vector<ShapeHandle>(num_inputs),
std::vector<const Tensor*>(num_inputs, nullptr), {},
std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
{},
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
// Set input_shapes.
@ -1102,7 +1122,6 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
c.SetInput(i, c.MakeShape(dims));
}
// TODO(bgogul): Handle input_tensors.
// TODO(bgogul): Handle input_tensors_as_shapes.
// TODO(bgogul): Handle input_resource_shapes_and_types.

View File

@ -378,10 +378,15 @@ TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeList(
TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeListArray(
TF_ShapeAndTypeList** shape_list_array, int num_items);
// Infer shapes for the given `node_def`. The arguments mimic the arguments of
// the `shape_inference::InferenceContext` constructor. The types need not be
// set in `input_shapes` as it is not used for shape inference. The number of
// `input_tensors` should be the same as the number of items in `input_shapes`.
// Infer shapes for the given `op`. The arguments mimic the arguments of the
// `shape_inference::InferenceContext` constructor. Note the following:
// - The inputs of the `op` are not used for shape inference. So, it is
// OK to not have the inputs properly set in `op`. See `input_tensors`
// if you want shape inference to consider the input tensors of the
// op for shape inference.
// - The types need not be set in `input_shapes` as it is not used.
// - The number of `input_tensors` should be the same as the number of items
// in `input_shapes`.
//
// The results are returned in `output_shapes` and
// `output_resource_shapes_and_types`. The caller is responsible for freeing the

View File

@ -439,7 +439,6 @@ class ShapeInferenceTest : public ::testing::Test {
: status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) {
tfe_context_ = TFE_NewContext(tfe_context_options_, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
}
~ShapeInferenceTest() override {
@ -452,7 +451,7 @@ class ShapeInferenceTest : public ::testing::Test {
void CheckOutputShapes(
TFE_Op* op,
const std::vector<absl::optional<std::vector<int64_t>>>& input_shapes_vec,
TF_Tensor** input_tensors,
const std::vector<TF_Tensor*>& input_tensors,
const absl::optional<std::vector<int64_t>>& expected_shape) {
// Create input_shapes.
TF_ShapeAndTypeList* input_shapes =
@ -467,7 +466,10 @@ class ShapeInferenceTest : public ::testing::Test {
}
}
TF_ShapeAndTypeList* output_shapes;
TFE_InferShapes(op, input_shapes, input_tensors,
TFE_InferShapes(op, input_shapes,
input_tensors.empty()
? nullptr
: const_cast<TF_Tensor**>(input_tensors.data()),
/*input_tensors_as_shapes*/ nullptr,
/*input_resource_shapes_and_types*/ nullptr, &output_shapes,
/*output_resource_shapes_and_types*/ nullptr, status_);
@ -515,31 +517,65 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) {
// Infer shape when everything is known.
CheckOutputShapes(matmul_op,
/*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})},
/*input_tensors*/ nullptr,
/*input_tensors*/ {},
/*expected_shape*/ make_shape({3, 4}));
// Infer shape when second operand has unknown shape.
CheckOutputShapes(matmul_op,
/*input_shapes*/ {make_shape({3, 2}), unknown_shape()},
/*input_tensors*/ nullptr,
/*input_tensors*/ {},
/*expected_shape*/ make_shape({3, kUnknownDim}));
// Infer shape when some dimensions are unknown.
CheckOutputShapes(
matmul_op,
/*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})},
/*input_tensors*/ nullptr,
/*input_tensors*/ {},
/*expected_shape*/ make_shape({kUnknownDim, 4}));
// Infer shape when everything is unknown.
CheckOutputShapes(matmul_op,
/*input_shapes*/ {unknown_shape(), unknown_shape()},
/*input_tensors*/ nullptr,
/*input_tensors*/ {},
/*expected_shape*/ make_shape({kUnknownDim, kUnknownDim}));
TFE_DeleteOp(matmul_op);
// TODO(bgogul): Add some death tests where status is not OK.
}
TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
// Prepare some tensors for shape.
TF_Tensor* tensor_1X6 = Int32Tensor({1, 6});
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TF_Tensor* tensor_1X1X6 = Int32Tensor({1, 1, 6});
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_Op* reshape_op = TFE_NewOp(tfe_context_, "Reshape", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(reshape_op, "T", TF_FLOAT);
TFE_OpSetAttrType(reshape_op, "Tshape", TF_INT32);
CheckOutputShapes(reshape_op,
/* input_shapes*/ {unknown_shape(), unknown_shape()},
/* input_tensors*/ {nullptr, tensor_1X6},
/*expected_shape*/ make_shape({1, 6}));
TFE_DeleteOp(reshape_op);
reshape_op = nullptr;
TFE_Op* fill_op = TFE_NewOp(tfe_context_, "Fill", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(fill_op, "T", TF_FLOAT);
TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
CheckOutputShapes(fill_op,
/* input_shapes*/ {unknown_shape(), unknown_shape()},
/* input_tensors*/ {tensor_1X1X6, nullptr},
/*expected_shape*/ make_shape({1, 1, 6}));
TFE_DeleteOp(fill_op);
fill_op = nullptr;
TF_DeleteTensor(tensor_1X1X6);
TF_DeleteTensor(tensor_1X6);
}
} // namespace
} // namespace tensorflow