diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 68ac81053d9..28c42148f35 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -493,6 +493,22 @@ cc_library(
     ],
 )
 
+tf_cc_test(
+    name = "convert_type_test",
+    size = "small",
+    srcs = ["utils/convert_type_test.cc"],
+    deps = [
+        ":convert_type",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/stream_executor/lib",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+    ],
+)
+
 cc_library(
     name = "convert_tensor",
     srcs = ["utils/convert_tensor.cc"],
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
index f536e3885dd..fc7227a2b21 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
@@ -26,7 +26,9 @@ limitations under the License.
 #include "mlir/IR/Identifier.h"  // TF:local_config_mlir
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
 #include "mlir/IR/OpImplementation.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
+#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
 #include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
 
 namespace mlir {
@@ -48,6 +50,22 @@ static bool IsStrArrayAttr(Attribute attr) {
 // TensorFlowSavedModelDialect Op's
 //===----------------------------------------------------------------------===//
 
+LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) {
+  if (!t1.isa<TensorType>() || !t2.isa<TensorType>()) {
+    return failure();
+  }
+  return verifyCompatibleShape(t1.cast<TensorType>(), t2.cast<TensorType>());
+}
+
+static LogicalResult Verify(GlobalTensorOp global_tensor) {
+  if (failed(VerifyTensorTypesCompatible(
+          global_tensor.type(), global_tensor.value().Attribute::getType()))) {
+    return global_tensor.emitError() << "'type' and 'value' attributes should "
+                                        "have compatible tensor types";
+  }
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td
index 59b0b7dba00..b08a83d28f7 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td
@@ -111,13 +111,23 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
 
     The `value` attribute contains the tensor's value (or initial value, in the
     case it is mutable).
 
+    The `type` attribute contains the tensor's type, which for the case of
+    mutable tensors might be more general than just the fixed static shape of
+    the `value` attribute. For example, a global tensor might be unranked or
+    have a partially known shape, such as `tensor<*xf32>` or `tensor<4x?x27xf32>`.
+    The shape of `value` must be compatible with the shape of `type` in the
+    sense of `tf.TensorShape` compatibility, and the element types must match.
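+
+    For example (a hypothetical op instance, not one of the tests in this
+    change), a mutable global tensor may pair a partially known `type` with a
+    statically shaped initial `value`:
+
+      "tf_saved_model.global_tensor"() {
+        is_mutable,
+        sym_name = "v",
+        type = tensor<?x27xf32>,
+        value = dense<1.0> : tensor<4x27xf32>
+      } : () -> ()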
   }];
 
   let arguments = (ins
     StrAttr:$sym_name,
     ElementsAttr:$value,
+    TypeAttr:$type,
     UnitAttr:$is_mutable
   );
+
+  let verifier = [{ return Verify(*this); }];
 }
 
 #endif // SAVED_MODEL_DIALECT
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py
index 279af55079b..fce0981e1ec 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py
@@ -31,8 +31,8 @@ class TestModule(tf.Module):
     self.v42 = tf.Variable(42.0)
     self.c43 = tf.constant(43.0)
 
-  # CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], value = dense<4.200000e+01> : tensor<f32>} : () -> ()
-  # CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], value = dense<4.300000e+01> : tensor<f32>} : () -> ()
+  # CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], type = tensor<f32>, value = dense<4.200000e+01> : tensor<f32>} : () -> ()
+  # CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], type = tensor<f32>, value = dense<4.300000e+01> : tensor<f32>} : () -> ()
   # CHECK: func {{@[a-zA-Z_0-9]+}}(
   # CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
   # CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @[[VAR]]},
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py
new file mode 100644
index 00000000000..d5797ab6a44
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py
@@ -0,0 +1,38 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# RUN: %p/partially_shaped_variables | FileCheck %s
+
+# pylint: disable=missing-docstring,line-too-long
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v2 as tf
+from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common
+
+
+class TestModule(tf.Module):
+
+  def __init__(self):
+    super(TestModule, self).__init__()
+    # CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v0"], type = tensor<*xf32>, value = dense<0.000000e+00> : tensor<1xf32>} : () -> ()
+    # CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v1"], type = tensor<?xf32>, value = dense<0.000000e+00> : tensor<1xf32>} : () -> ()
+    self.v0 = tf.Variable([0.], shape=tf.TensorShape(None))
+    self.v1 = tf.Variable([0.], shape=[None])
+
+
+if __name__ == '__main__':
+  common.do_test(TestModule, exported_names=[])
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir
index 2ed538f9c1a..5f1e96430b5 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir
@@ -5,7 +5,7 @@ module attributes {tf_saved_model.semantics} {
   // Test case: Simple case of inlining.
 
   // CHECK-NOT: tf_saved_model.global_tensor
-  "tf_saved_model.global_tensor"() { sym_name = "c", value = dense<1.0> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
 
   // CHECK: func @f()
   func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
@@ -23,7 +23,7 @@ module attributes {tf_saved_model.semantics} {
   // Test case: Do not inline mutable global tensors.
 
   // CHECK: tf_saved_model.global_tensor
-  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<1.0> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
 
   // CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
   func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir
index 9d9cc683f81..ea2b383f3bb 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir
@@ -7,6 +7,7 @@ module attributes {tf_saved_model.semantics} {
   "tf_saved_model.global_tensor"() {
     tf_saved_model.exported_names = ["some_const"],
     sym_name = "some_constant",
+    type = tensor<1x64xf32>,
     value = dense<42.0> : tensor<1x64xf32>
   } : () -> ()
 
@@ -16,6 +17,7 @@ module attributes {tf_saved_model.semantics} {
     is_mutable,
     tf_saved_model.exported_names = ["some_var", "some.other.name"],
     sym_name = "some_variable",
+    type = tensor<?x64xf32>,
    value = dense<42.0> : tensor<1x64xf32>
   } : () -> ()
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir
index c562b968194..5507e0e0115 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir
@@ -133,7 +133,7 @@ module attributes {tf_saved_model.semantics} {
 
 module attributes {tf_saved_model.semantics} {
 
-  "tf_saved_model.global_tensor"() { sym_name = "some_constant", value = dense<42.0> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { sym_name = "some_constant", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
 
   // expected-error@+1 {{all 'tf_saved_model.index_path' arg attributes should precede all 'tf_saved_model.bound_input' arg attributes}}
   func @f(
@@ -205,3 +205,10 @@ module attributes {tf_saved_model.semantics} {
   }
 
 }
+
+// -----
+
+module attributes {tf_saved_model.semantics} {
+  // expected-error@+1 {{'type' and 'value' attributes should have compatible tensor types}}
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v0", type = tensor<3xf32>, value = dense<42.0> : tensor<9xf32> } : () -> ()
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir
index 58d16d97c71..95b0bd54d70 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir
@@ -11,7 +11,7 @@ module attributes {tf_saved_model.semantics} {
   // CHECK: "tf_saved_model.global_tensor"() {
   // CHECK-NOT: is_mutable
   // CHECK-SAME: } : () -> ()
-  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<42.> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   // CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
   func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
@@ -33,7 +33,7 @@ module attributes {tf_saved_model.semantics} {
   // CHECK: "tf_saved_model.global_tensor"() {
   // CHECK-SAME: is_mutable
   // CHECK-SAME: } : () -> ()
-  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<42.> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   // CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
   func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
@@ -55,7 +55,7 @@ module attributes {tf_saved_model.semantics} {
   // CHECK: "tf_saved_model.global_tensor"() {
   // CHECK: is_mutable
   // CHECK-SAME: } : () -> ()
-  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], value = dense<42.> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   // CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
   func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
@@ -88,7 +88,7 @@ module attributes {tf_saved_model.semantics} {
 
   // Test case: Check that an immutable bound input isn't modified.
 
-  "tf_saved_model.global_tensor"() { sym_name = "c", value = dense<42.> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   // CHECK: func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
   func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c}) -> (tensor<f32> {tf_saved_model.index_path = []})
@@ -109,13 +109,13 @@ module attributes {tf_saved_model.semantics} {
 
   // Test case: Check that an exported global tensor that isn't bound to an
   // argument is not erased.
 
-  "tf_saved_model.global_tensor"() { sym_name = "exported_unbound", tf_saved_model.exported_names = ["exported_unbound"], value = dense<42.> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { sym_name = "exported_unbound", tf_saved_model.exported_names = ["exported_unbound"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
   // CHECK: sym_name = "exported_unbound"
 
   // Test case: Check that a global tensor that isn't even bound to an argument
   // is erased.
 
-  "tf_saved_model.global_tensor"() { sym_name = "unexported_unbound", value = dense<42.> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { sym_name = "unexported_unbound", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
   // CHECK-NOT: sym_name = "unexported_unbound"
 }
 
@@ -131,7 +131,7 @@ module attributes {tf_saved_model.semantics} {
   // We erase the argument that this global tensor is bound to, so we delete
   // the global tensor too.
   // CHECK-NOT: tf_saved_model.global_tensor
-  "tf_saved_model.global_tensor"() { sym_name = "c", value = dense<42.> : tensor<f32> } : () -> ()
+  "tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   // CHECK: func @f()
   func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index a3880827a97..77d3fc4bbca 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -84,6 +84,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/struct.pb.h" #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" static inline absl::string_view StringRefToView(llvm::StringRef ref) { return {ref.data(), ref.size()}; @@ -2341,11 +2342,16 @@ Status CreateSavedModelIR( TF_ASSIGN_OR_RETURN( Tensor value, ReadVariableFromSession(saved_model, variable.name())); TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder)); - + // A variable can have a partially known type, such as tensor, + // even if the initializer is a specific static shape. + TF_ASSIGN_OR_RETURN( + auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(), + &builder)); auto op = builder.create( builder.getUnknownLoc(), builder.getStringAttr(object_names.GetSymbolTableName(node_id)), value_attr, + /*type=*/mlir::TypeAttr::get(type), /*is_mutable=*/builder.getUnitAttr()); op.setAttr( "tf_saved_model.exported_names", @@ -2360,6 +2366,7 @@ Status CreateSavedModelIR( builder.getUnknownLoc(), builder.getStringAttr(object_names.GetSymbolTableName(node_id)), value_attr, + /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()), /*is_mutable=*/nullptr); op.setAttr( "tf_saved_model.exported_names", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index fef1ca4a551..1c1f9803bd7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -52,27 +52,6 @@ using mlir::ShapedType; using mlir::Type; using tensorflow::errors::InvalidArgument; -void ConvertToMlirShape(const TensorShape& input_shape, - llvm::SmallVectorImpl* shape) { - shape->reserve(input_shape.dims()); - for (const auto& d : input_shape) { - shape->push_back(d.size); - } -} - -Status ConvertToMlirShape(const TensorShapeProto& input_shape, - llvm::SmallVectorImpl* shape) { - shape->reserve(input_shape.dim_size()); - auto& dims = input_shape.dim(); - for (auto& d : dims) { - if (d.size() > std::numeric_limits::max()) { - return InvalidArgument("Shape element overflows"); - } - shape->push_back(d.size()); - } - return Status::OK(); -} - static TensorProto ConvertToProto(const Tensor& input_tensor, bool use_tensor_content = true) { TensorProto tensor_proto; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index f57cd1c872a..f4b6f702cef 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -29,10 +29,6 @@ namespace tensorflow { using stream_executor::port::StatusOr; -// Converts an TensorFlow shape proto to the one used in MLIR. -Status ConvertToMlirShape(const TensorShapeProto& input_shape, - llvm::SmallVectorImpl* shape); - // Converts an TensorFlow tensor proto into an MLIR elements attribute. 
 StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
                                                 mlir::Builder* builder);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc
index a78cf12f096..e2d970c8dfd 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc
@@ -31,7 +31,7 @@ using mlir::Builder;
 using mlir::ShapedType;
 using mlir::Type;
 
-Status ConvertDataType(const DataType& dtype, Builder builder, Type* type) {
+Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
   switch (dtype) {
     case DT_HALF:
       *type = builder.getF16Type();
@@ -149,4 +149,38 @@ Status ConvertToDataType(Type type, DataType* dtype) {
   return Status::OK();
 }
 
+void ConvertToMlirShape(const TensorShape& input_shape,
+                        llvm::SmallVectorImpl<int64_t>* shape) {
+  shape->reserve(input_shape.dims());
+  for (const auto& d : input_shape) {
+    shape->push_back(d.size);
+  }
+}
+
+Status ConvertToMlirShape(const TensorShapeProto& input_shape,
+                          llvm::SmallVectorImpl<int64_t>* shape) {
+  shape->reserve(input_shape.dim_size());
+  auto& dims = input_shape.dim();
+  for (auto& d : dims) {
+    if (d.size() > std::numeric_limits<int64_t>::max()) {
+      return errors::InvalidArgument("Shape element overflows");
+    }
+    shape->push_back(d.size());
+  }
+  return Status::OK();
+}
+
+StatusOr<mlir::Type> ConvertToMlirTensorType(const TensorShapeProto& shape,
+                                             DataType dtype,
+                                             mlir::Builder* builder) {
+  mlir::Type element_type;
+  TF_RETURN_IF_ERROR(ConvertDataType(dtype, *builder, &element_type));
+  if (shape.unknown_rank()) {
+    return mlir::UnrankedTensorType::get(element_type);
+  }
+  llvm::SmallVector<int64_t, 4> shape_dims;
+  TF_RETURN_IF_ERROR(ConvertToMlirShape(shape, &shape_dims));
+  return mlir::RankedTensorType::get(shape_dims, element_type);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h
index 369da7cc480..fa5c92c12fe 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h
@@ -18,14 +18,17 @@ limitations under the License.
 
 #include "mlir/IR/Builders.h"  // TF:local_config_mlir
 #include "mlir/IR/Types.h"  // TF:local_config_mlir
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace tensorflow {
 
+using stream_executor::port::StatusOr;
+
 // Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type.
-Status ConvertDataType(const DataType& dtype, mlir::Builder builder,
-                       mlir::Type* type);
+Status ConvertDataType(DataType dtype, mlir::Builder builder, mlir::Type* type);
 
 // Converts a scalar MLIR type to a TensorFlow Datatype.
 Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype);
@@ -34,6 +37,19 @@ Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype);
 // is converted directly. If it is a shaped type, the element type is converted.
 Status ConvertToDataType(mlir::Type type, DataType* dtype);
 
+// Converts a TensorFlow shape to the one used in MLIR.
+void ConvertToMlirShape(const TensorShape& input_shape,
+                        llvm::SmallVectorImpl<int64_t>* shape);
+
+// Converts a TensorFlow shape proto to the one used in MLIR.
+Status ConvertToMlirShape(const TensorShapeProto& input_shape,
+                          llvm::SmallVectorImpl<int64_t>* shape);
+
+// Given a tensor shape and dtype, get the corresponding MLIR tensor type.
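+// For example (illustrative): dims {-1, 27} with DT_FLOAT map to
+// `tensor<?x27xf32>`, and a shape proto with unknown_rank() set maps to
+// `tensor<*xf32>`.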
+StatusOr<mlir::Type> ConvertToMlirTensorType(const TensorShapeProto& shape,
+                                             DataType dtype,
+                                             mlir::Builder* builder);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc
new file mode 100644
index 00000000000..423d61dc2c6
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc
@@ -0,0 +1,64 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
+
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace tensorflow {
+namespace {
+
+std::string ConvertToMlirString(const std::vector<int64_t>& dims,
+                                bool unknown_rank, DataType dtype) {
+  TensorShapeProto shape;
+  shape.set_unknown_rank(unknown_rank);
+  for (int64_t dim : dims) {
+    shape.add_dim()->set_size(dim);
+  }
+  mlir::MLIRContext context;
+  mlir::Builder b(&context);
+  auto status_or = ConvertToMlirTensorType(shape, dtype, &b);
+  std::string buf;
+  llvm::raw_string_ostream os(buf);
+  status_or.ValueOrDie().print(os);
+  return os.str();
+}
+
+TEST(MlirConvertType, ConvertToMlirTensorType) {
+  // Simple case of static shapes.
+  EXPECT_EQ("tensor<4x8x16xi32>",
+            ConvertToMlirString({4, 8, 16}, /*unknown_rank=*/false,
+                                DataType::DT_INT32));
+
+  // Partially known shapes.
+  EXPECT_EQ("tensor<?x27x?xbf16>",
+            ConvertToMlirString({-1, 27, -1}, /*unknown_rank=*/false,
+                                DataType::DT_BFLOAT16));
+
+  // Unranked shapes.
+  EXPECT_EQ("tensor<*xf32>",
+            ConvertToMlirString({}, /*unknown_rank=*/true, DataType::DT_FLOAT));
+}
+
+}  // namespace
+
+}  // namespace tensorflow
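
As a summary of the compatibility rule the new verifier enforces (illustrative pairs only, following the `tf.TensorShape` semantics described in the op documentation above, not additional tests in this change):

    // Compatible: unknown dimensions and unranked types subsume static shapes.
    //   type = tensor<?x27xf32>  with  value : tensor<4x27xf32>
    //   type = tensor<*xf32>     with  value : tensor<1xf32>
    // Incompatible: static dimensions that disagree, or ranks that differ.
    //   type = tensor<3xf32>     with  value : tensor<9xf32>
    //   type = tensor<f32>       with  value : tensor<1xf32>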