[C API/Eager]: Fix bug in TFE_OpSetAttrString.
TFE_OpSetAttrString was holding on to the 'value' pointer
after it returned. This bug was introduced in commit
2b0805301e
which caused TFE_OpSetAttrString to invoke
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
instead of:
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, T&& value)
(where the latter copies 'value' when T is a StringPiece or const char*
and the former aliases the memory pointed to by StringPiece).
In this process, I realized that
AttrBuilder::Set(StringPiece attr_name, StringPiece&& value)
was never being invoked (other than in this buggy situation),
so I removed it altogether.
Without the changes to attr_builder.{h,cc}, the newly added test
fails - complaining that "NHWC" is not a valid value for the "padding"
attribute.
PiperOrigin-RevId: 209017110
This commit is contained in:
parent
e77dbdb205
commit
72c5efff88
tensorflow
@ -1471,4 +1471,61 @@ void BM_ReadVariable(int iters) {
|
||||
}
|
||||
BENCHMARK(BM_ReadVariable);
|
||||
|
||||
TEST(CAPI, StringAttributes) {
|
||||
// Test that TFE_OpSetAttrString doesn't hold on to the value after it
|
||||
// returns.
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::vector<int64_t> dims(4, 1);
|
||||
TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_Tensor* tensor =
|
||||
TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
|
||||
float tensor_data[] = {1};
|
||||
memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
|
||||
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, tensor_handle, status);
|
||||
TF_DeleteTensor(tensor);
|
||||
TFE_DeleteTensorHandle(tensor_handle);
|
||||
|
||||
std::vector<int64_t> values(4, 1);
|
||||
TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size());
|
||||
TFE_OpSetAttrIntList(op, "strides", values.data(), values.size());
|
||||
|
||||
const int BUFFER_SIZE = 10;
|
||||
char buffer[BUFFER_SIZE];
|
||||
std::strncpy(buffer, "VALID", BUFFER_SIZE);
|
||||
TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer));
|
||||
// Overwriting value in "buffer", should be fine since TFE_Op
|
||||
// shouldn't be holding on to it.
|
||||
std::strncpy(buffer, "NHWC", BUFFER_SIZE);
|
||||
TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer));
|
||||
|
||||
TFE_OpSetAttrType(op, "T", TF_FLOAT);
|
||||
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &retvals[0], &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
|
||||
tensor = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(4, TF_TensorByteSize(tensor));
|
||||
TF_DeleteTensor(tensor);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
} // namespace
|
||||
|
@ -103,7 +103,6 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
|
||||
return *this; \
|
||||
}
|
||||
|
||||
DEFINE_SET_ATTR(StringPiece, string_attrs_);
|
||||
DEFINE_SET_ATTR(float, float_attrs_);
|
||||
DEFINE_SET_ATTR(int, int_attrs_);
|
||||
DEFINE_SET_ATTR(bool, bool_attrs_);
|
||||
@ -119,9 +118,6 @@ AttrBuilder& AttrBuilder::NumInputs(int n) {
|
||||
|
||||
void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
|
||||
bool include_those_in_node_def) const {
|
||||
for (const auto& p : string_attrs_) {
|
||||
SetInAttrValueMap(m, p.first, p.second);
|
||||
}
|
||||
for (const auto& p : int_attrs_) {
|
||||
SetInAttrValueMap(m, p.first, p.second);
|
||||
}
|
||||
@ -211,10 +207,6 @@ tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
|
||||
// not been called.
|
||||
if (node_def_finalized_) return f;
|
||||
}
|
||||
for (const auto& p : string_attrs_) {
|
||||
CombineUnordered(
|
||||
CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
|
||||
}
|
||||
for (const auto& p : int_attrs_) {
|
||||
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
|
||||
&f);
|
||||
|
@ -131,7 +131,6 @@ class AttrBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
AttrVec<StringPiece> string_attrs_;
|
||||
AttrVec<int> int_attrs_;
|
||||
AttrVec<float> float_attrs_;
|
||||
AttrVec<bool> bool_attrs_;
|
||||
@ -142,8 +141,6 @@ class AttrBuilder {
|
||||
bool node_def_finalized_;
|
||||
}; // namespace tensorflow
|
||||
|
||||
template <>
|
||||
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
|
||||
template <>
|
||||
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
|
||||
template <>
|
||||
|
Loading…
Reference in New Issue
Block a user