diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 2920c94b4c2..6528d21dc36 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -1553,8 +1553,67 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, TFE_OpSetAttrFunction(op, attr_name, func_op); TFE_DeleteOp(func_op); } break; - case tensorflow::AttrValue::kList: - TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kList: { + // String + if (const int s_size = default_value.list().s_size()) { + absl::InlinedVector<const void*, 4> values_vector; + absl::InlinedVector<size_t, 4> lengths_vector; + for (int i = 0; i < s_size; ++i) { + const string& v = default_value.list().s(i); + values_vector.push_back(v.data()); + lengths_vector.push_back(v.size()); + } + TFE_OpSetAttrStringList(op, attr_name, values_vector.data(), + lengths_vector.data(), s_size); + } + + // Int + if (const int i_size = default_value.list().i_size()) { + absl::InlinedVector<int64_t, 4> i_vector; + for (int i = 0; i < i_size; ++i) { + i_vector.push_back(default_value.list().i(i)); + } + TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size); + } + // Float + if (const int f_size = default_value.list().f_size()) { + absl::InlinedVector<float, 4> f_vector; + for (int i = 0; i < f_size; ++i) { + f_vector.push_back(default_value.list().f(i)); + } + TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size); + } + // Bool + if (const int b_size = default_value.list().b_size()) { + absl::InlinedVector<unsigned char, 4> b_vector; + for (int i = 0; i < b_size; i++) { + b_vector.push_back(default_value.list().b(i)); + } + TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size); + } + // Type + if (const int type_size = default_value.list().type_size()) { + absl::InlinedVector<unsigned int, 4> type_vector; + for (int i = 0; i < type_size; ++i) { + type_vector.push_back(default_value.list().type(i)); + } + TFE_OpSetAttrTypeList( + op, attr_name, + reinterpret_cast<const TF_DataType*>(type_vector.data()), + type_size); + } + + // Rest are not supported. + if (default_value.list().shape_size() > 0 || + default_value.list().func_size() > 0 || + default_value.list().tensor_size() > 0) { + TF_SetStatus( + status, TF_UNIMPLEMENTED, + tensorflow::strings::StrCat("Unable to get setfor default value: ", + default_value.DebugString()) + .data()); + } + } break; case tensorflow::AttrValue::kTensor: TF_FALLTHROUGH_INTENDED; case tensorflow::AttrValue::kPlaceholder: diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 37bb9c5f16b..fd208c6770d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include <string> // clang-format off +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/platform/platform.h" // clang-format on @@ -1191,6 +1192,68 @@ TEST(CAPI, StringAttributes) { TF_DeleteStatus(status); } +// Same test as above, expect use SetOpAttrValueScalar to set attrs. +TEST(CAPI, TestTFE_SetOpAttrs) { + // Test that TFE_OpSetAttrString doesn't hold on to the value after it + // returns. + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::vector<int64_t> dims(4, 1); + TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* tensor = + TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float)); + float tensor_data[] = {1}; + memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor)); + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, tensor_handle, status); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(tensor_handle); + + tensorflow::AttrValue i_list_values; + for (int i = 0; i < 4; ++i) { + i_list_values.mutable_list()->add_i(1); + } + SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status); + SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status); + + tensorflow::AttrValue padding_value; + *padding_value.mutable_s() = "VALID"; + tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status); + + tensorflow::AttrValue data_format_value; + *data_format_value.mutable_s() = "NHWC"; + tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format", + status); + + TFE_OpSetAttrType(op, "T", TF_FLOAT); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(op, &retvals[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + tensor = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(4, TF_TensorByteSize(tensor)); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(op); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( TF_NewStatus(), TF_DeleteStatus); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index f114d1724f2..fff985efa6f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -492,3 +492,13 @@ func @DontFoldTile() -> (tensor<8x10000xi32>) { return %3 : tensor<8x10000xi32> } // LINT.ThenChange(../transforms/constant_fold.cc:folding-policy) + +func @fold_conv() -> tensor<1x520x520x1xf32> { + %0 = "tf.Const"() {value = dense<0.111111112> : tensor<3x3x1x1xf32>} : () -> tensor<3x3x1x1xf32> + %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1x520x520x1xf32>} : () -> tensor<1x520x520x1xf32> + %2 = "tf.DepthwiseConv2dNative"(%1, %0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x520x520x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x520x520x1xf32> + return %2 : tensor<1x520x520x1xf32> + + // CHECK: tf.Const + // CHECK-NOT: tf.DepthwiseConv2dNative +}