From 9182431be6b9e64e271c0e9fc4a57289166ab2b1 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 19 Nov 2019 10:34:58 -0800 Subject: [PATCH] tf_saved_model: Add explicit tracking of variable partial types. Variables in TF can change type, within the bounds of the partial type that they are given at creation time. This CL adds another field to tf_saved_model.global_tensor to explicitly track that partial type. For example, a variable can be created with partial shape `tensor<*xf32>`, and then it can be assigned any compatible type (i.e. any f32 tensor) at any point in the program. The same applies for `tensor`, but the types that are legal to assign to it are now more restricted. This CL also refactors some of the utilities a little bit. PiperOrigin-RevId: 281325922 Change-Id: Ifa04f827670681839574432b9d9d0ba2c0ee5f79 --- tensorflow/compiler/mlir/tensorflow/BUILD | 16 +++++ .../mlir/tensorflow/ir/tf_saved_model.cc | 18 ++++++ .../mlir/tensorflow/ir/tf_saved_model_ops.td | 10 +++ .../tensorflow/tests/tf_saved_model/basic.py | 4 +- .../partially_shaped_variables.py | 38 +++++++++++ .../tf_saved_model_inline_global_tensors.mlir | 4 +- .../tensorflow/tests/tf_saved_model_ops.mlir | 2 + .../tests/tf_saved_model_ops_invalid.mlir | 9 ++- ...f_saved_model_optimize_global_tensors.mlir | 14 ++-- .../mlir/tensorflow/translate/import_model.cc | 9 ++- .../mlir/tensorflow/utils/convert_tensor.cc | 21 ------ .../mlir/tensorflow/utils/convert_tensor.h | 4 -- .../mlir/tensorflow/utils/convert_type.cc | 36 ++++++++++- .../mlir/tensorflow/utils/convert_type.h | 20 +++++- .../tensorflow/utils/convert_type_test.cc | 64 +++++++++++++++++++ 15 files changed, 228 insertions(+), 41 deletions(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 68ac81053d9..28c42148f35 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -493,6 +493,22 @@ cc_library( ], ) +tf_cc_test( + name = "convert_type_test", + size = "small", + srcs = ["utils/convert_type_test.cc"], + deps = [ + ":convert_type", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@llvm//:support", + "@local_config_mlir//:IR", + ], +) + cc_library( name = "convert_tensor", srcs = ["utils/convert_tensor.cc"], diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index f536e3885dd..fc7227a2b21 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -26,7 +26,9 @@ limitations under the License. #include "mlir/IR/Identifier.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/OpImplementation.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/SymbolTable.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir namespace mlir { @@ -48,6 +50,22 @@ static bool IsStrArrayAttr(Attribute attr) { // TensorFlowSavedModelDialect Op's //===----------------------------------------------------------------------===// +LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) { + if (!t1.isa() || !t2.isa()) { + return failure(); + } + return verifyCompatibleShape(t1.cast(), t2.cast()); +} + +static LogicalResult Verify(GlobalTensorOp global_tensor) { + if (failed(VerifyTensorTypesCompatible( + global_tensor.type(), global_tensor.value().Attribute::getType()))) { + return global_tensor.emitError() << "'type' and 'value' attributes should " + "have compatible tensor types"; + } + return success(); +} + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td index 59b0b7dba00..b08a83d28f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td @@ -111,13 +111,23 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> { The `value` attribute contains the tensor's value (or initial value, in the case it is mutable). + + The `type` attribute contains the tensor's type, which for the case of + mutable tensors might be more general than just the fixed static shape of + the `value` attribute. For example, a global tensor might be unranked such + as `tensor<*xf32>`, or a more complex shape such as `tensor<4x?x27xf32>`. + The shape of `value` must be compatible with the shape of `type` in the + sense of `tf.TensorShape` compatibility. And the element types must match. }]; let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value, + TypeAttr:$type, UnitAttr:$is_mutable ); + + let verifier = [{ return Verify(*this); }]; } #endif // SAVED_MODEL_DIALECT diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py index 279af55079b..fce0981e1ec 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py @@ -31,8 +31,8 @@ class TestModule(tf.Module): self.v42 = tf.Variable(42.0) self.c43 = tf.constant(43.0) - # CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], value = dense<4.200000e+01> : tensor} : () -> () - # CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], value = dense<4.300000e+01> : tensor} : () -> () + # CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], type = tensor, value = dense<4.200000e+01> : tensor} : () -> () + # CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], type = tensor, value = dense<4.300000e+01> : tensor} : () -> () # CHECK: func {{@[a-zA-Z_0-9]+}}( # CHECK-SAME: %arg0: tensor {tf_saved_model.index_path = [0]}, # CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @[[VAR]]}, diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py new file mode 100644 index 00000000000..d5797ab6a44 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py @@ -0,0 +1,38 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/partially_shaped_variables | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common + + +class TestModule(tf.Module): + + def __init__(self): + super(TestModule, self).__init__() + # CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v0"], type = tensor<*xf32>, value = dense<0.000000e+00> : tensor<1xf32>} : () -> () + # CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v1"], type = tensor, value = dense<0.000000e+00> : tensor<1xf32>} : () -> () + self.v0 = tf.Variable([0.], shape=tf.TensorShape(None)) + self.v1 = tf.Variable([0.], shape=[None]) + + +if __name__ == '__main__': + common.do_test(TestModule, exported_names=[]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir index 2ed538f9c1a..5f1e96430b5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_inline_global_tensors.mlir @@ -5,7 +5,7 @@ module attributes {tf_saved_model.semantics} { // Test case: Simple case of inlining. // CHECK-NOT: tf_saved_model.global_tensor - "tf_saved_model.global_tensor"() { sym_name = "c", value = dense<1.0> : tensor } : () -> () + "tf_saved_model.global_tensor"() { sym_name = "c", type = tensor, value = dense<1.0> : tensor } : () -> () // CHECK: func @f() func @f(%arg0: tensor {tf_saved_model.bound_input = @c}) @@ -23,7 +23,7 @@ module attributes {tf_saved_model.semantics} { // Test case: Do not inline mutable global tensors. // CHECK: tf_saved_model.global_tensor - "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<1.0> : tensor } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<1.0> : tensor } : () -> () // CHECK: func @f(%arg0: tensor {tf_saved_model.bound_input = @v}) func @f(%arg0: tensor {tf_saved_model.bound_input = @v}) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir index 9d9cc683f81..ea2b383f3bb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir @@ -7,6 +7,7 @@ module attributes {tf_saved_model.semantics} { "tf_saved_model.global_tensor"() { tf_saved_model.exported_names = ["some_const"], sym_name = "some_constant", + type = tensor<1x64xf32>, value = dense<42.0> : tensor<1x64xf32> } : () -> () @@ -16,6 +17,7 @@ module attributes {tf_saved_model.semantics} { is_mutable, tf_saved_model.exported_names = ["some_var", "some.other.name"], sym_name = "some_variable", + type = tensor, value = dense<42.0> : tensor<1x64xf32> } : () -> () diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir index c562b968194..5507e0e0115 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir @@ -133,7 +133,7 @@ module attributes {tf_saved_model.semantics} { module attributes {tf_saved_model.semantics} { - "tf_saved_model.global_tensor"() { sym_name = "some_constant", value = dense<42.0> : tensor } : () -> () + "tf_saved_model.global_tensor"() { sym_name = "some_constant", type = tensor, value = dense<42.0> : tensor } : () -> () // expected-error@+1 {{all 'tf_saved_model.index_path' arg attributes should precede all 'tf_saved_model.bound_input' arg attributes}} func @f( @@ -205,3 +205,10 @@ module attributes {tf_saved_model.semantics} { } } + +// ----- + +module attributes {tf_saved_model.semantics} { + // expected-error@+1 {{'type' and 'value' attributes should have compatible tensor types}} + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v0", type = tensor<3xf32>, value = dense<42.0> : tensor<9xf32> } : () -> () +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir index 58d16d97c71..95b0bd54d70 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir @@ -11,7 +11,7 @@ module attributes {tf_saved_model.semantics} { // CHECK: "tf_saved_model.global_tensor"() { // CHECK-NOT: is_mutable // CHECK-SAME: } : () -> () - "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<42.> : tensor } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () // CHECK: func @f(%arg0: tensor {tf_saved_model.bound_input = @v}) func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor {tf_saved_model.index_path = []}) @@ -33,7 +33,7 @@ module attributes {tf_saved_model.semantics} { // CHECK: "tf_saved_model.global_tensor"() { // CHECK-SAME: is_mutable // CHECK-SAME: } : () -> () - "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<42.> : tensor } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () // CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) @@ -55,7 +55,7 @@ module attributes {tf_saved_model.semantics} { // CHECK: "tf_saved_model.global_tensor"() { // CHECK: is_mutable // CHECK-SAME: } : () -> () - "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], value = dense<42.> : tensor } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor, value = dense<42.> : tensor } : () -> () // CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor {tf_saved_model.index_path = []}) @@ -88,7 +88,7 @@ module attributes {tf_saved_model.semantics} { // Test case: Check that an immutable bound input isn't modified. - "tf_saved_model.global_tensor"() { sym_name = "c", value = dense<42.> : tensor } : () -> () + "tf_saved_model.global_tensor"() { sym_name = "c", type = tensor, value = dense<42.> : tensor } : () -> () // CHECK: func @h(%arg0: tensor {tf_saved_model.bound_input = @c}) func @h(%arg0: tensor {tf_saved_model.bound_input = @c}) -> (tensor {tf_saved_model.index_path = []}) @@ -109,13 +109,13 @@ module attributes {tf_saved_model.semantics} { // Test case: Check that an exported global tensor that isn't bound to an // argument is not erased. - "tf_saved_model.global_tensor"() { sym_name = "exported_unbound", tf_saved_model.exported_names = ["exported_unbound"], value = dense<42.> : tensor } : () -> () + "tf_saved_model.global_tensor"() { sym_name = "exported_unbound", tf_saved_model.exported_names = ["exported_unbound"], type = tensor, value = dense<42.> : tensor } : () -> () // CHECK: sym_name = "exported_unbound" // Test case: Check that a global tensor that isn't even bound to an argument // is erased. - "tf_saved_model.global_tensor"() { sym_name = "unexported_unbound", value = dense<42.> : tensor } : () -> () + "tf_saved_model.global_tensor"() { sym_name = "unexported_unbound", type = tensor, value = dense<42.> : tensor } : () -> () // CHECK-NOT: sym_name = "unexported_unbound" } @@ -131,7 +131,7 @@ module attributes {tf_saved_model.semantics} { // We erase the argument that this global tensor is bound to, so we delete // the global tensor too. // CHECK-NOT: tf_saved_model.global_tensor - "tf_saved_model.global_tensor"() { sym_name = "c", value = dense<42.> : tensor } : () -> () + "tf_saved_model.global_tensor"() { sym_name = "c", type = tensor, value = dense<42.> : tensor } : () -> () // CHECK: func @f() func @f(%arg0: tensor {tf_saved_model.bound_input = @c}) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index a3880827a97..77d3fc4bbca 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -84,6 +84,7 @@ limitations under the License. #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/struct.pb.h" #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" static inline absl::string_view StringRefToView(llvm::StringRef ref) { return {ref.data(), ref.size()}; @@ -2341,11 +2342,16 @@ Status CreateSavedModelIR( TF_ASSIGN_OR_RETURN( Tensor value, ReadVariableFromSession(saved_model, variable.name())); TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder)); - + // A variable can have a partially known type, such as tensor, + // even if the initializer is a specific static shape. + TF_ASSIGN_OR_RETURN( + auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(), + &builder)); auto op = builder.create( builder.getUnknownLoc(), builder.getStringAttr(object_names.GetSymbolTableName(node_id)), value_attr, + /*type=*/mlir::TypeAttr::get(type), /*is_mutable=*/builder.getUnitAttr()); op.setAttr( "tf_saved_model.exported_names", @@ -2360,6 +2366,7 @@ Status CreateSavedModelIR( builder.getUnknownLoc(), builder.getStringAttr(object_names.GetSymbolTableName(node_id)), value_attr, + /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()), /*is_mutable=*/nullptr); op.setAttr( "tf_saved_model.exported_names", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index fef1ca4a551..1c1f9803bd7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -52,27 +52,6 @@ using mlir::ShapedType; using mlir::Type; using tensorflow::errors::InvalidArgument; -void ConvertToMlirShape(const TensorShape& input_shape, - llvm::SmallVectorImpl* shape) { - shape->reserve(input_shape.dims()); - for (const auto& d : input_shape) { - shape->push_back(d.size); - } -} - -Status ConvertToMlirShape(const TensorShapeProto& input_shape, - llvm::SmallVectorImpl* shape) { - shape->reserve(input_shape.dim_size()); - auto& dims = input_shape.dim(); - for (auto& d : dims) { - if (d.size() > std::numeric_limits::max()) { - return InvalidArgument("Shape element overflows"); - } - shape->push_back(d.size()); - } - return Status::OK(); -} - static TensorProto ConvertToProto(const Tensor& input_tensor, bool use_tensor_content = true) { TensorProto tensor_proto; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index f57cd1c872a..f4b6f702cef 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -29,10 +29,6 @@ namespace tensorflow { using stream_executor::port::StatusOr; -// Converts an TensorFlow shape proto to the one used in MLIR. -Status ConvertToMlirShape(const TensorShapeProto& input_shape, - llvm::SmallVectorImpl* shape); - // Converts an TensorFlow tensor proto into an MLIR elements attribute. StatusOr ConvertTensorProto(const TensorProto& input_tensor, mlir::Builder* builder); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index a78cf12f096..e2d970c8dfd 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -31,7 +31,7 @@ using mlir::Builder; using mlir::ShapedType; using mlir::Type; -Status ConvertDataType(const DataType& dtype, Builder builder, Type* type) { +Status ConvertDataType(DataType dtype, Builder builder, Type* type) { switch (dtype) { case DT_HALF: *type = builder.getF16Type(); @@ -149,4 +149,38 @@ Status ConvertToDataType(Type type, DataType* dtype) { return Status::OK(); } +void ConvertToMlirShape(const TensorShape& input_shape, + llvm::SmallVectorImpl* shape) { + shape->reserve(input_shape.dims()); + for (const auto& d : input_shape) { + shape->push_back(d.size); + } +} + +Status ConvertToMlirShape(const TensorShapeProto& input_shape, + llvm::SmallVectorImpl* shape) { + shape->reserve(input_shape.dim_size()); + auto& dims = input_shape.dim(); + for (auto& d : dims) { + if (d.size() > std::numeric_limits::max()) { + return errors::InvalidArgument("Shape element overflows"); + } + shape->push_back(d.size()); + } + return Status::OK(); +} + +StatusOr ConvertToMlirTensorType(const TensorShapeProto& shape, + DataType dtype, + mlir::Builder* builder) { + mlir::Type element_type; + TF_RETURN_IF_ERROR(ConvertDataType(dtype, *builder, &element_type)); + if (shape.unknown_rank()) { + return mlir::UnrankedTensorType::get(element_type); + } + llvm::SmallVector shape_dims; + TF_RETURN_IF_ERROR(ConvertToMlirShape(shape, &shape_dims)); + return mlir::RankedTensorType::get(shape_dims, element_type); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h index 369da7cc480..fa5c92c12fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h @@ -18,14 +18,17 @@ limitations under the License. #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { +using stream_executor::port::StatusOr; + // Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type. -Status ConvertDataType(const DataType& dtype, mlir::Builder builder, - mlir::Type* type); +Status ConvertDataType(DataType dtype, mlir::Builder builder, mlir::Type* type); // Converts a scalar MLIR type to a TensorFlow Datatype. Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype); @@ -34,6 +37,19 @@ Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype); // is converted directly. If it is a shaped type, the element type is converted. Status ConvertToDataType(mlir::Type type, DataType* dtype); +// Converts an TensorFlow shape to the one used in MLIR. +void ConvertToMlirShape(const TensorShape& input_shape, + llvm::SmallVectorImpl* shape); + +// Converts an TensorFlow shape proto to the one used in MLIR. +Status ConvertToMlirShape(const TensorShapeProto& input_shape, + llvm::SmallVectorImpl* shape); + +// Given a tensor shape and dtype, get the corresponding MLIR tensor type. +StatusOr ConvertToMlirTensorType(const TensorShapeProto& shape, + DataType dtype, + mlir::Builder* builder); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc new file mode 100644 index 00000000000..423d61dc2c6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace { + +std::string ConvertToMlirString(const std::vector& dims, + bool unknown_rank, DataType dtype) { + TensorShapeProto shape; + shape.set_unknown_rank(unknown_rank); + for (int64_t dim : dims) { + shape.add_dim()->set_size(dim); + } + mlir::MLIRContext context; + mlir::Builder b(&context); + auto status_or = ConvertToMlirTensorType(shape, dtype, &b); + std::string buf; + llvm::raw_string_ostream os(buf); + status_or.ValueOrDie().print(os); + return os.str(); +} + +TEST(MlirConvertType, ConvertToMlirTensorType) { + // Simple case of static shapes. + EXPECT_EQ("tensor<4x8x16xi32>", + ConvertToMlirString({4, 8, 16}, /*unknown_rank=*/false, + DataType::DT_INT32)); + + // Partially known shapes. + EXPECT_EQ("tensor", + ConvertToMlirString({-1, 27, -1}, /*unknown_rank=*/false, + DataType::DT_BFLOAT16)); + + // Unranked shapes. + EXPECT_EQ("tensor<*xf32>", + ConvertToMlirString({}, /*unknown_rank=*/true, DataType::DT_FLOAT)); +} + +} // namespace + +} // namespace tensorflow