Partial support for list attrs in TFE_Op; this unblocks const-folding for ops like depthwise_conv in MLIR.

PiperOrigin-RevId: 329238224
Change-Id: If61677796448bab056c104c5dbb264829a0627c1
Renjie Liu 2020-08-30 21:57:26 -07:00 committed by TensorFlower Gardener
parent c0757ec6ed
commit eb1c8b0ff7
3 changed files with 134 additions and 2 deletions
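
For orientation, a minimal sketch of what this change enables (assuming an existing TFE_Context* ctx, TFE_Op* op, and TF_Status* status, as in the test below): an AttrValue carrying a list payload can now be forwarded to an eager op instead of hitting the unimplemented fallthrough.

// Sketch only; ctx, op, and status are assumed to already exist.
tensorflow::AttrValue strides;
for (int i = 0; i < 4; ++i) strides.mutable_list()->add_i(1);  // {1, 1, 1, 1}
// Before this change the kList case fell through to TF_UNIMPLEMENTED; now it
// dispatches to TFE_OpSetAttrIntList (and the string/float/bool/type
// equivalents for other list payloads).
tensorflow::SetOpAttrValueScalar(ctx, op, strides, "strides", status);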


@@ -1553,8 +1553,67 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
      TFE_OpSetAttrFunction(op, attr_name, func_op);
      TFE_DeleteOp(func_op);
    } break;
    case tensorflow::AttrValue::kList: {
      // String
      if (const int s_size = default_value.list().s_size()) {
        absl::InlinedVector<const void*, 4> values_vector;
        absl::InlinedVector<size_t, 4> lengths_vector;
        for (int i = 0; i < s_size; ++i) {
          const string& v = default_value.list().s(i);
          values_vector.push_back(v.data());
          lengths_vector.push_back(v.size());
        }
        TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
                                lengths_vector.data(), s_size);
      }
      // Int
      if (const int i_size = default_value.list().i_size()) {
        absl::InlinedVector<int64_t, 4> i_vector;
        for (int i = 0; i < i_size; ++i) {
          i_vector.push_back(default_value.list().i(i));
        }
        TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
      }
      // Float
      if (const int f_size = default_value.list().f_size()) {
        absl::InlinedVector<float, 4> f_vector;
        for (int i = 0; i < f_size; ++i) {
          f_vector.push_back(default_value.list().f(i));
        }
        TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
      }
      // Bool
      if (const int b_size = default_value.list().b_size()) {
        absl::InlinedVector<unsigned char, 4> b_vector;
        for (int i = 0; i < b_size; ++i) {
          b_vector.push_back(default_value.list().b(i));
        }
        TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
      }
      // Type
      if (const int type_size = default_value.list().type_size()) {
        absl::InlinedVector<unsigned int, 4> type_vector;
        for (int i = 0; i < type_size; ++i) {
          type_vector.push_back(default_value.list().type(i));
        }
        TFE_OpSetAttrTypeList(
            op, attr_name,
            reinterpret_cast<const TF_DataType*>(type_vector.data()),
            type_size);
      }
      // The remaining list types (shape, func, tensor) are not supported.
      if (default_value.list().shape_size() > 0 ||
          default_value.list().func_size() > 0 ||
          default_value.list().tensor_size() > 0) {
        TF_SetStatus(
            status, TF_UNIMPLEMENTED,
            tensorflow::strings::StrCat(
                "Unable to set attr for default value: ",
                default_value.DebugString())
                .data());
      }
    } break;
    case tensorflow::AttrValue::kTensor:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kPlaceholder:
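
A note on the structure above: AttrValue.list is a proto with parallel repeated fields (s, i, f, b, type, shape, tensor, func), and more than one of them can be populated at the proto level, which is why each size is checked independently rather than switching on a single tag. A small illustration:

tensorflow::AttrValue v;
v.mutable_list()->add_i(1);     // i_size() == 1
v.mutable_list()->add_f(0.5f);  // f_size() == 1; legal at the proto level
// Both the Int and Float branches above would fire for this value.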


@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
// clang-format off
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on
@@ -1191,6 +1192,68 @@ TEST(CAPI, StringAttributes) {
  TF_DeleteStatus(status);
}
// Same test as above, except it uses SetOpAttrValueScalar to set the attrs.
TEST(CAPI, TestTFE_SetOpAttrs) {
  // Test that SetOpAttrValueScalar can populate both string and int-list
  // attributes on an eager op.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::vector<int64_t> dims(4, 1);
  TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* tensor =
      TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
  float tensor_data[] = {1};
  memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, tensor_handle, status);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(tensor_handle);

  tensorflow::AttrValue i_list_values;
  for (int i = 0; i < 4; ++i) {
    i_list_values.mutable_list()->add_i(1);
  }
  SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status);
  SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status);

  tensorflow::AttrValue padding_value;
  *padding_value.mutable_s() = "VALID";
  tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status);

  tensorflow::AttrValue data_format_value;
  *data_format_value.mutable_s() = "NHWC";
  tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format",
                                   status);

  TFE_OpSetAttrType(op, "T", TF_FLOAT);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, &retvals[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  tensor = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(4, TF_TensorByteSize(tensor));
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteOp(op);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
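
The final assertions follow from standard pooling arithmetic: with a 1x1x1x1 input and ksize = strides = {1, 1, 1, 1} under VALID padding, the AvgPool output is also 1x1x1x1, so the resolved tensor holds exactly one float (4 bytes). A quick check of the VALID output-size formula (illustrative helper, not part of the test):

// VALID padding: out = (in - ksize) / stride + 1 per spatial dimension.
int ValidOutDim(int in, int ksize, int stride) {
  return (in - ksize) / stride + 1;
}
// ValidOutDim(1, 1, 1) == 1 for every dimension, so the output is a single
// float: 4 bytes, matching EXPECT_EQ(4, TF_TensorByteSize(tensor)).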
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);


@@ -492,3 +492,13 @@ func @DontFoldTile() -> (tensor<8x10000xi32>) {
  return %3 : tensor<8x10000xi32>
}
// LINT.ThenChange(../transforms/constant_fold.cc:folding-policy)
func @fold_conv() -> tensor<1x520x520x1xf32> {
  %0 = "tf.Const"() {value = dense<0.111111112> : tensor<3x3x1x1xf32>} : () -> tensor<3x3x1x1xf32>
  %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1x520x520x1xf32>} : () -> tensor<1x520x520x1xf32>
  %2 = "tf.DepthwiseConv2dNative"(%1, %0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x520x520x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x520x520x1xf32>
  return %2 : tensor<1x520x520x1xf32>

  // CHECK: tf.Const
  // CHECK-NOT: tf.DepthwiseConv2dNative
}
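
As a sanity check on the values being folded (illustrative arithmetic, not part of the test): the kernel is a 3x3 of roughly 1/9 and the input is all ones, so each interior output of the SAME-padded depthwise conv sums nine taps of 1 * 1/9, i.e. approximately 1.0, while border outputs accumulate fewer taps:

// Illustrative only: one interior output element of the conv above.
float acc = 0.0f;
for (int kh = 0; kh < 3; ++kh) {
  for (int kw = 0; kw < 3; ++kw) {
    acc += 1.0f * 0.111111112f;  // all-ones input, dense<0.111111112> kernel
  }
}
// acc is ~1.0 at interior positions; SAME-padded borders see fewer taps.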