partial support for list attr in TFE_Op, this unblocks const-folding for ops like depthwise_conv in mlir.
PiperOrigin-RevId: 329238224 Change-Id: If61677796448bab056c104c5dbb264829a0627c1
This commit is contained in:
parent
c0757ec6ed
commit
eb1c8b0ff7
@ -1553,8 +1553,67 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
|||||||
TFE_OpSetAttrFunction(op, attr_name, func_op);
|
TFE_OpSetAttrFunction(op, attr_name, func_op);
|
||||||
TFE_DeleteOp(func_op);
|
TFE_DeleteOp(func_op);
|
||||||
} break;
|
} break;
|
||||||
case tensorflow::AttrValue::kList:
|
case tensorflow::AttrValue::kList: {
|
||||||
TF_FALLTHROUGH_INTENDED;
|
// String
|
||||||
|
if (const int s_size = default_value.list().s_size()) {
|
||||||
|
absl::InlinedVector<const void*, 4> values_vector;
|
||||||
|
absl::InlinedVector<size_t, 4> lengths_vector;
|
||||||
|
for (int i = 0; i < s_size; ++i) {
|
||||||
|
const string& v = default_value.list().s(i);
|
||||||
|
values_vector.push_back(v.data());
|
||||||
|
lengths_vector.push_back(v.size());
|
||||||
|
}
|
||||||
|
TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
|
||||||
|
lengths_vector.data(), s_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int
|
||||||
|
if (const int i_size = default_value.list().i_size()) {
|
||||||
|
absl::InlinedVector<int64_t, 4> i_vector;
|
||||||
|
for (int i = 0; i < i_size; ++i) {
|
||||||
|
i_vector.push_back(default_value.list().i(i));
|
||||||
|
}
|
||||||
|
TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
|
||||||
|
}
|
||||||
|
// Float
|
||||||
|
if (const int f_size = default_value.list().f_size()) {
|
||||||
|
absl::InlinedVector<float, 4> f_vector;
|
||||||
|
for (int i = 0; i < f_size; ++i) {
|
||||||
|
f_vector.push_back(default_value.list().f(i));
|
||||||
|
}
|
||||||
|
TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
|
||||||
|
}
|
||||||
|
// Bool
|
||||||
|
if (const int b_size = default_value.list().b_size()) {
|
||||||
|
absl::InlinedVector<unsigned char, 4> b_vector;
|
||||||
|
for (int i = 0; i < b_size; i++) {
|
||||||
|
b_vector.push_back(default_value.list().b(i));
|
||||||
|
}
|
||||||
|
TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
|
||||||
|
}
|
||||||
|
// Type
|
||||||
|
if (const int type_size = default_value.list().type_size()) {
|
||||||
|
absl::InlinedVector<unsigned int, 4> type_vector;
|
||||||
|
for (int i = 0; i < type_size; ++i) {
|
||||||
|
type_vector.push_back(default_value.list().type(i));
|
||||||
|
}
|
||||||
|
TFE_OpSetAttrTypeList(
|
||||||
|
op, attr_name,
|
||||||
|
reinterpret_cast<const TF_DataType*>(type_vector.data()),
|
||||||
|
type_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rest are not supported.
|
||||||
|
if (default_value.list().shape_size() > 0 ||
|
||||||
|
default_value.list().func_size() > 0 ||
|
||||||
|
default_value.list().tensor_size() > 0) {
|
||||||
|
TF_SetStatus(
|
||||||
|
status, TF_UNIMPLEMENTED,
|
||||||
|
tensorflow::strings::StrCat("Unable to get setfor default value: ",
|
||||||
|
default_value.DebugString())
|
||||||
|
.data());
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case tensorflow::AttrValue::kTensor:
|
case tensorflow::AttrValue::kTensor:
|
||||||
TF_FALLTHROUGH_INTENDED;
|
TF_FALLTHROUGH_INTENDED;
|
||||||
case tensorflow::AttrValue::kPlaceholder:
|
case tensorflow::AttrValue::kPlaceholder:
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/platform/platform.h"
|
#include "tensorflow/core/platform/platform.h"
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
@ -1191,6 +1192,68 @@ TEST(CAPI, StringAttributes) {
|
|||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Same test as above, expect use SetOpAttrValueScalar to set attrs.
|
||||||
|
TEST(CAPI, TestTFE_SetOpAttrs) {
|
||||||
|
// Test that TFE_OpSetAttrString doesn't hold on to the value after it
|
||||||
|
// returns.
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
std::vector<int64_t> dims(4, 1);
|
||||||
|
TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TF_Tensor* tensor =
|
||||||
|
TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
|
||||||
|
float tensor_data[] = {1};
|
||||||
|
memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
|
||||||
|
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(op, tensor_handle, status);
|
||||||
|
TF_DeleteTensor(tensor);
|
||||||
|
TFE_DeleteTensorHandle(tensor_handle);
|
||||||
|
|
||||||
|
tensorflow::AttrValue i_list_values;
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
i_list_values.mutable_list()->add_i(1);
|
||||||
|
}
|
||||||
|
SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status);
|
||||||
|
SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status);
|
||||||
|
|
||||||
|
tensorflow::AttrValue padding_value;
|
||||||
|
*padding_value.mutable_s() = "VALID";
|
||||||
|
tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status);
|
||||||
|
|
||||||
|
tensorflow::AttrValue data_format_value;
|
||||||
|
*data_format_value.mutable_s() = "NHWC";
|
||||||
|
tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format",
|
||||||
|
status);
|
||||||
|
|
||||||
|
TFE_OpSetAttrType(op, "T", TF_FLOAT);
|
||||||
|
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_TensorHandle* retvals[1];
|
||||||
|
int num_retvals = 1;
|
||||||
|
TFE_Execute(op, &retvals[0], &num_retvals, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
ASSERT_EQ(1, num_retvals);
|
||||||
|
|
||||||
|
tensor = TFE_TensorHandleResolve(retvals[0], status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
EXPECT_EQ(4, TF_TensorByteSize(tensor));
|
||||||
|
TF_DeleteTensor(tensor);
|
||||||
|
TFE_DeleteTensorHandle(retvals[0]);
|
||||||
|
|
||||||
|
TFE_DeleteOp(op);
|
||||||
|
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
@ -492,3 +492,13 @@ func @DontFoldTile() -> (tensor<8x10000xi32>) {
|
|||||||
return %3 : tensor<8x10000xi32>
|
return %3 : tensor<8x10000xi32>
|
||||||
}
|
}
|
||||||
// LINT.ThenChange(../transforms/constant_fold.cc:folding-policy)
|
// LINT.ThenChange(../transforms/constant_fold.cc:folding-policy)
|
||||||
|
|
||||||
|
func @fold_conv() -> tensor<1x520x520x1xf32> {
|
||||||
|
%0 = "tf.Const"() {value = dense<0.111111112> : tensor<3x3x1x1xf32>} : () -> tensor<3x3x1x1xf32>
|
||||||
|
%1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1x520x520x1xf32>} : () -> tensor<1x520x520x1xf32>
|
||||||
|
%2 = "tf.DepthwiseConv2dNative"(%1, %0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x520x520x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x520x520x1xf32>
|
||||||
|
return %2 : tensor<1x520x520x1xf32>
|
||||||
|
|
||||||
|
// CHECK: tf.Const
|
||||||
|
// CHECK-NOT: tf.DepthwiseConv2dNative
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user