Merge pull request #26592 from karllessard:eager-name-ranges
PiperOrigin-RevId: 239507640
This commit is contained in:
commit
ef5a9f03dd
@ -1527,7 +1527,7 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
|
||||
if (TF_GetCode(status) != TF_OK) return -1;
|
||||
auto iter = name_ranges.find(arg_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
status->status = InvalidArgument("Input arg '", arg_name, "' not found");
|
||||
status->status = InvalidArgument("Output arg '", arg_name, "' not found");
|
||||
return -1;
|
||||
}
|
||||
return iter->second.second - iter->second.first;
|
||||
|
@ -63,6 +63,17 @@ using tensorflow::int64;
|
||||
using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
|
||||
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
|
||||
if (op->inference_ctx) {
|
||||
return op->inference_ctx->op_def;
|
||||
}
|
||||
const tensorflow::OpDef* op_def;
|
||||
status->status =
|
||||
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
|
||||
return op_def;
|
||||
}
|
||||
|
||||
bool IsCPU(const tensorflow::Device* d) {
|
||||
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
|
||||
}
|
||||
@ -807,6 +818,54 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
funcs.get(), num_values));
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
||||
const char* input_name,
|
||||
TF_Status* status) {
|
||||
const tensorflow::OpDef* op_def = GetOpDef(op, status);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
tensorflow::AttrValueMap attrs;
|
||||
op->operation.Attrs().FillAttrValueMap(&attrs);
|
||||
tensorflow::NameRangeMap name_ranges;
|
||||
status->status = tensorflow::NameRangesForNode(
|
||||
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
auto iter = name_ranges.find(input_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
|
||||
"' not found");
|
||||
return -1;
|
||||
}
|
||||
return iter->second.second - iter->second.first;
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
const char* output_name,
|
||||
TF_Status* status) {
|
||||
const tensorflow::OpDef* op_def = GetOpDef(op, status);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
tensorflow::AttrValueMap attrs;
|
||||
op->operation.Attrs().FillAttrValueMap(&attrs);
|
||||
tensorflow::NameRangeMap name_ranges;
|
||||
status->status = tensorflow::NameRangesForNode(
|
||||
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
auto iter = name_ranges.find(output_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Output '", output_name, "' not found");
|
||||
return -1;
|
||||
}
|
||||
return iter->second.second - iter->second.first;
|
||||
}
|
||||
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Calling TFE_Execute() on op " << op;
|
||||
|
@ -367,6 +367,18 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op,
|
||||
const TFE_Op** value,
|
||||
int num_values);
|
||||
|
||||
// Returns the length (number of tensors) of the input argument `input_name`
|
||||
// found in the provided `op`.
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
||||
const char* input_name,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns the length (number of tensors) of the output argument `output_name`
|
||||
// found in the provided `op`.
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
const char* output_name,
|
||||
TF_Status* status);
|
||||
|
||||
// Execute the operation defined by 'op' and return handles to computed
|
||||
// tensors in `retvals`.
|
||||
//
|
||||
|
@ -1781,4 +1781,78 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
||||
TFE_DeleteTensorHandle(dim);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Try to retrieve lengths before building the attributes (should fail)
|
||||
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status));
|
||||
CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
|
||||
CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInputList(identityOp, inputs, 2, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Try to retrieve lengths before executing the op (should work)
|
||||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[2] = {nullptr};
|
||||
int num_retvals = 2;
|
||||
TFE_Execute(identityOp, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Try to retrieve lengths after executing the op (should work)
|
||||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(identityOp);
|
||||
TFE_DeleteTensorHandle(input1);
|
||||
TFE_DeleteTensorHandle(input2);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
TFE_DeleteTensorHandle(retvals[1]);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInputList(identityOp, inputs, 2, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "cheese", status));
|
||||
CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "cheese", status));
|
||||
CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(identityOp);
|
||||
TFE_DeleteTensorHandle(input1);
|
||||
TFE_DeleteTensorHandle(input2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -555,15 +555,14 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
|
||||
|
||||
namespace { // Helpers for NameRangesForNode()
|
||||
|
||||
Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
|
||||
Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
|
||||
const OpDef& op_def, int* num) {
|
||||
if (!arg_def.number_attr().empty()) {
|
||||
// Same type repeated "num" times.
|
||||
return GetNodeAttr(node_def, arg_def.number_attr(), num);
|
||||
return GetNodeAttr(attrs, arg_def.number_attr(), num);
|
||||
} else if (!arg_def.type_list_attr().empty()) {
|
||||
const AttrValue* attr_value;
|
||||
TF_RETURN_IF_ERROR(
|
||||
AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
|
||||
TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
|
||||
*num = attr_value->list().type_size();
|
||||
} else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
|
||||
*num = 1;
|
||||
@ -575,13 +574,13 @@ Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NameRangesHelper(const NodeDef& node_def,
|
||||
Status NameRangesHelper(const AttrSlice& attrs,
|
||||
const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
|
||||
const OpDef& op_def, NameRangeMap* result) {
|
||||
int start = 0;
|
||||
int num;
|
||||
for (const auto& arg : args) {
|
||||
TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num));
|
||||
TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
|
||||
(*result)[arg.name()] = std::make_pair(start, start + num);
|
||||
start += num;
|
||||
}
|
||||
@ -590,14 +589,14 @@ Status NameRangesHelper(const NodeDef& node_def,
|
||||
|
||||
} // namespace
|
||||
|
||||
Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
|
||||
Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
|
||||
NameRangeMap* inputs, NameRangeMap* outputs) {
|
||||
if (inputs != nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs));
|
||||
NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
|
||||
}
|
||||
if (outputs != nullptr) {
|
||||
return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs);
|
||||
return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -294,7 +294,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
|
||||
// returned `NameRangeMap` objects.
|
||||
typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
|
||||
NameRangeMap;
|
||||
Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
|
||||
Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
|
||||
NameRangeMap* inputs, NameRangeMap* outputs);
|
||||
Status NameRangesForNode(const Node& node, const OpDef& op_def,
|
||||
NameRangeMap* inputs, NameRangeMap* outputs);
|
||||
|
Loading…
Reference in New Issue
Block a user