Merge pull request #26592 from karllessard:eager-name-ranges

PiperOrigin-RevId: 239507640
This commit is contained in:
TensorFlower Gardener 2019-03-20 17:24:01 -07:00
commit ef5a9f03dd
6 changed files with 155 additions and 11 deletions

View File

@ -1527,7 +1527,7 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
if (TF_GetCode(status) != TF_OK) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
status->status = InvalidArgument("Input arg '", arg_name, "' not found");
status->status = InvalidArgument("Output arg '", arg_name, "' not found");
return -1;
}
return iter->second.second - iter->second.first;

View File

@ -63,6 +63,17 @@ using tensorflow::int64;
using tensorflow::string;
namespace {
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
if (op->inference_ctx) {
return op->inference_ctx->op_def;
}
const tensorflow::OpDef* op_def;
status->status =
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
return op_def;
}
bool IsCPU(const tensorflow::Device* d) {
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
@ -807,6 +818,54 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
funcs.get(), num_values));
}
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
"' not found");
return -1;
}
return iter->second.second - iter->second.first;
}
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument(
"Output '", output_name, "' not found");
return -1;
}
return iter->second.second - iter->second.first;
}
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;

View File

@ -367,6 +367,18 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op,
const TFE_Op** value,
int num_values);
// Returns the length (number of tensors) of the input argument `input_name`
// found in the provided `op`.
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status);
// Returns the length (number of tensors) of the output argument `output_name`
// found in the provided `op`.
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status);
// Execute the operation defined by 'op' and return handles to computed
// tensors in `retvals`.
//

View File

@ -1781,4 +1781,78 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_DeleteTensorHandle(dim);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Try to retrieve lengths before building the attributes (should fail)
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInputList(identityOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Try to retrieve lengths before executing the op (should work)
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[2] = {nullptr};
int num_retvals = 2;
TFE_Execute(identityOp, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Try to retrieve lengths after executing the op (should work)
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteOp(identityOp);
TFE_DeleteTensorHandle(input1);
TFE_DeleteTensorHandle(input2);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(retvals[1]);
}
TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInputList(identityOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "cheese", status));
CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "cheese", status));
CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteOp(identityOp);
TFE_DeleteTensorHandle(input1);
TFE_DeleteTensorHandle(input2);
}
} // namespace

View File

@ -555,15 +555,14 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
namespace { // Helpers for NameRangesForNode()
Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
const OpDef& op_def, int* num) {
if (!arg_def.number_attr().empty()) {
// Same type repeated "num" times.
return GetNodeAttr(node_def, arg_def.number_attr(), num);
return GetNodeAttr(attrs, arg_def.number_attr(), num);
} else if (!arg_def.type_list_attr().empty()) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(
AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
*num = attr_value->list().type_size();
} else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
*num = 1;
@ -575,13 +574,13 @@ Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
return Status::OK();
}
Status NameRangesHelper(const NodeDef& node_def,
Status NameRangesHelper(const AttrSlice& attrs,
const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
const OpDef& op_def, NameRangeMap* result) {
int start = 0;
int num;
for (const auto& arg : args) {
TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num));
TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
(*result)[arg.name()] = std::make_pair(start, start + num);
start += num;
}
@ -590,14 +589,14 @@ Status NameRangesHelper(const NodeDef& node_def,
} // namespace
Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
NameRangeMap* inputs, NameRangeMap* outputs) {
if (inputs != nullptr) {
TF_RETURN_IF_ERROR(
NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs));
NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
}
if (outputs != nullptr) {
return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs);
return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
}
return Status::OK();
}

View File

@ -294,7 +294,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
// returned `NameRangeMap` objects.
typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
NameRangeMap;
Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
NameRangeMap* inputs, NameRangeMap* outputs);
Status NameRangesForNode(const Node& node, const OpDef& op_def,
NameRangeMap* inputs, NameRangeMap* outputs);