tf_saved_model: Add explicit tracking of variable partial types.

Variables in TF can change type, within the bounds of the partial type that
they are given at creation time. This CL adds another field to
tf_saved_model.global_tensor to explicitly track that partial type.

For example, a variable can be created with partial shape `tensor<*xf32>`,
and then it can be assigned any compatible type (i.e. any f32 tensor) at
any point in the program.
The same applies for `tensor<?x27x?xf32>`, but the types that are legal to
assign to it are now more restricted.

This CL also refactors some of the utilities a little bit.

PiperOrigin-RevId: 281325922
Change-Id: Ifa04f827670681839574432b9d9d0ba2c0ee5f79
This commit is contained in:
Sean Silva 2019-11-19 10:34:58 -08:00 committed by TensorFlower Gardener
parent 6c092464ff
commit 9182431be6
15 changed files with 228 additions and 41 deletions

View File

@ -493,6 +493,22 @@ cc_library(
], ],
) )
tf_cc_test(
name = "convert_type_test",
size = "small",
srcs = ["utils/convert_type_test.cc"],
deps = [
":convert_type",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
],
)
cc_library( cc_library(
name = "convert_tensor", name = "convert_tensor",
srcs = ["utils/convert_tensor.cc"], srcs = ["utils/convert_tensor.cc"],

View File

@ -26,7 +26,9 @@ limitations under the License.
#include "mlir/IR/Identifier.h" // TF:local_config_mlir #include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir #include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir #include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
namespace mlir { namespace mlir {
@ -48,6 +50,22 @@ static bool IsStrArrayAttr(Attribute attr) {
// TensorFlowSavedModelDialect Op's // TensorFlowSavedModelDialect Op's
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) {
if (!t1.isa<TensorType>() || !t2.isa<TensorType>()) {
return failure();
}
return verifyCompatibleShape(t1.cast<TensorType>(), t2.cast<TensorType>());
}
static LogicalResult Verify(GlobalTensorOp global_tensor) {
if (failed(VerifyTensorTypesCompatible(
global_tensor.type(), global_tensor.value().Attribute::getType()))) {
return global_tensor.emitError() << "'type' and 'value' attributes should "
"have compatible tensor types";
}
return success();
}
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"

View File

@ -111,13 +111,23 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
The `value` attribute contains the tensor's value (or initial value, in the The `value` attribute contains the tensor's value (or initial value, in the
case it is mutable). case it is mutable).
The `type` attribute contains the tensor's type, which for the case of
mutable tensors might be more general than just the fixed static shape of
the `value` attribute. For example, a global tensor might be unranked such
as `tensor<*xf32>`, or a more complex shape such as `tensor<4x?x27xf32>`.
The shape of `value` must be compatible with the shape of `type` in the
sense of `tf.TensorShape` compatibility. And the element types must match.
}]; }];
let arguments = (ins let arguments = (ins
StrAttr:$sym_name, StrAttr:$sym_name,
ElementsAttr:$value, ElementsAttr:$value,
TypeAttr:$type,
UnitAttr:$is_mutable UnitAttr:$is_mutable
); );
let verifier = [{ return Verify(*this); }];
} }
#endif // SAVED_MODEL_DIALECT #endif // SAVED_MODEL_DIALECT

View File

@ -31,8 +31,8 @@ class TestModule(tf.Module):
self.v42 = tf.Variable(42.0) self.v42 = tf.Variable(42.0)
self.c43 = tf.constant(43.0) self.c43 = tf.constant(43.0)
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], value = dense<4.200000e+01> : tensor<f32>} : () -> () # CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], type = tensor<f32>, value = dense<4.200000e+01> : tensor<f32>} : () -> ()
# CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], value = dense<4.300000e+01> : tensor<f32>} : () -> () # CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], type = tensor<f32>, value = dense<4.300000e+01> : tensor<f32>} : () -> ()
# CHECK: func {{@[a-zA-Z_0-9]+}}( # CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]}, # CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @[[VAR]]}, # CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @[[VAR]]},

View File

@ -0,0 +1,38 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/partially_shaped_variables | FileCheck %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common
class TestModule(tf.Module):
def __init__(self):
super(TestModule, self).__init__()
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v0"], type = tensor<*xf32>, value = dense<0.000000e+00> : tensor<1xf32>} : () -> ()
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v1"], type = tensor<?xf32>, value = dense<0.000000e+00> : tensor<1xf32>} : () -> ()
self.v0 = tf.Variable([0.], shape=tf.TensorShape(None))
self.v1 = tf.Variable([0.], shape=[None])
if __name__ == '__main__':
common.do_test(TestModule, exported_names=[])

View File

@ -5,7 +5,7 @@ module attributes {tf_saved_model.semantics} {
// Test case: Simple case of inlining. // Test case: Simple case of inlining.
// CHECK-NOT: tf_saved_model.global_tensor // CHECK-NOT: tf_saved_model.global_tensor
"tf_saved_model.global_tensor"() { sym_name = "c", value = dense<1.0> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
// CHECK: func @f() // CHECK: func @f()
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c}) func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
@ -23,7 +23,7 @@ module attributes {tf_saved_model.semantics} {
// Test case: Do not inline mutable global tensors. // Test case: Do not inline mutable global tensors.
// CHECK: tf_saved_model.global_tensor // CHECK: tf_saved_model.global_tensor
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<1.0> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v}) // CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v}) func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})

View File

@ -7,6 +7,7 @@ module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { "tf_saved_model.global_tensor"() {
tf_saved_model.exported_names = ["some_const"], tf_saved_model.exported_names = ["some_const"],
sym_name = "some_constant", sym_name = "some_constant",
type = tensor<1x64xf32>,
value = dense<42.0> : tensor<1x64xf32> value = dense<42.0> : tensor<1x64xf32>
} : () -> () } : () -> ()
@ -16,6 +17,7 @@ module attributes {tf_saved_model.semantics} {
is_mutable, is_mutable,
tf_saved_model.exported_names = ["some_var", "some.other.name"], tf_saved_model.exported_names = ["some_var", "some.other.name"],
sym_name = "some_variable", sym_name = "some_variable",
type = tensor<?x64xf32>,
value = dense<42.0> : tensor<1x64xf32> value = dense<42.0> : tensor<1x64xf32>
} : () -> () } : () -> ()

View File

@ -133,7 +133,7 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} { module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { sym_name = "some_constant", value = dense<42.0> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { sym_name = "some_constant", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
// expected-error@+1 {{all 'tf_saved_model.index_path' arg attributes should precede all 'tf_saved_model.bound_input' arg attributes}} // expected-error@+1 {{all 'tf_saved_model.index_path' arg attributes should precede all 'tf_saved_model.bound_input' arg attributes}}
func @f( func @f(
@ -205,3 +205,10 @@ module attributes {tf_saved_model.semantics} {
} }
} }
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{'type' and 'value' attributes should have compatible tensor types}}
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v0", type = tensor<3xf32>, value = dense<42.0> : tensor<9xf32> } : () -> ()
}

View File

@ -11,7 +11,7 @@ module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() { // CHECK: "tf_saved_model.global_tensor"() {
// CHECK-NOT: is_mutable // CHECK-NOT: is_mutable
// CHECK-SAME: } : () -> () // CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<42.> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v}) // CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []}) func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
@ -33,7 +33,7 @@ module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() { // CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable // CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> () // CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<42.> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) // CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
@ -55,7 +55,7 @@ module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() { // CHECK: "tf_saved_model.global_tensor"() {
// CHECK: is_mutable // CHECK: is_mutable
// CHECK-SAME: } : () -> () // CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], value = dense<42.> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) // CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []}) func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
@ -88,7 +88,7 @@ module attributes {tf_saved_model.semantics} {
// Test case: Check that an immutable bound input isn't modified. // Test case: Check that an immutable bound input isn't modified.
"tf_saved_model.global_tensor"() { sym_name = "c", value = dense<42.> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c}) // CHECK: func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c}) -> (tensor<f32> {tf_saved_model.index_path = []}) func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c}) -> (tensor<f32> {tf_saved_model.index_path = []})
@ -109,13 +109,13 @@ module attributes {tf_saved_model.semantics} {
// Test case: Check that an exported global tensor that isn't bound to an // Test case: Check that an exported global tensor that isn't bound to an
// argument is not erased. // argument is not erased.
"tf_saved_model.global_tensor"() { sym_name = "exported_unbound", tf_saved_model.exported_names = ["exported_unbound"], value = dense<42.> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { sym_name = "exported_unbound", tf_saved_model.exported_names = ["exported_unbound"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: sym_name = "exported_unbound" // CHECK: sym_name = "exported_unbound"
// Test case: Check that a global tensor that isn't even bound to an argument // Test case: Check that a global tensor that isn't even bound to an argument
// is erased. // is erased.
"tf_saved_model.global_tensor"() { sym_name = "unexported_unbound", value = dense<42.> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { sym_name = "unexported_unbound", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK-NOT: sym_name = "unexported_unbound" // CHECK-NOT: sym_name = "unexported_unbound"
} }
@ -131,7 +131,7 @@ module attributes {tf_saved_model.semantics} {
// We erase the argument that this global tensor is bound to, so we delete // We erase the argument that this global tensor is bound to, so we delete
// the global tensor too. // the global tensor too.
// CHECK-NOT: tf_saved_model.global_tensor // CHECK-NOT: tf_saved_model.global_tensor
"tf_saved_model.global_tensor"() { sym_name = "c", value = dense<42.> : tensor<f32> } : () -> () "tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @f() // CHECK: func @f()
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c}) func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})

View File

@ -84,6 +84,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h" #include "tensorflow/core/protobuf/struct.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
static inline absl::string_view StringRefToView(llvm::StringRef ref) { static inline absl::string_view StringRefToView(llvm::StringRef ref) {
return {ref.data(), ref.size()}; return {ref.data(), ref.size()};
@ -2341,11 +2342,16 @@ Status CreateSavedModelIR(
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
Tensor value, ReadVariableFromSession(saved_model, variable.name())); Tensor value, ReadVariableFromSession(saved_model, variable.name()));
TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder)); TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
// A variable can have a partially known type, such as tensor<?x27x?xf32>,
// even if the initializer is a specific static shape.
TF_ASSIGN_OR_RETURN(
auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(),
&builder));
auto op = builder.create<mlir::tf_saved_model::GlobalTensorOp>( auto op = builder.create<mlir::tf_saved_model::GlobalTensorOp>(
builder.getUnknownLoc(), builder.getUnknownLoc(),
builder.getStringAttr(object_names.GetSymbolTableName(node_id)), builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
value_attr, value_attr,
/*type=*/mlir::TypeAttr::get(type),
/*is_mutable=*/builder.getUnitAttr()); /*is_mutable=*/builder.getUnitAttr());
op.setAttr( op.setAttr(
"tf_saved_model.exported_names", "tf_saved_model.exported_names",
@ -2360,6 +2366,7 @@ Status CreateSavedModelIR(
builder.getUnknownLoc(), builder.getUnknownLoc(),
builder.getStringAttr(object_names.GetSymbolTableName(node_id)), builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
value_attr, value_attr,
/*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
/*is_mutable=*/nullptr); /*is_mutable=*/nullptr);
op.setAttr( op.setAttr(
"tf_saved_model.exported_names", "tf_saved_model.exported_names",

View File

@ -52,27 +52,6 @@ using mlir::ShapedType;
using mlir::Type; using mlir::Type;
using tensorflow::errors::InvalidArgument; using tensorflow::errors::InvalidArgument;
void ConvertToMlirShape(const TensorShape& input_shape,
llvm::SmallVectorImpl<int64_t>* shape) {
shape->reserve(input_shape.dims());
for (const auto& d : input_shape) {
shape->push_back(d.size);
}
}
Status ConvertToMlirShape(const TensorShapeProto& input_shape,
llvm::SmallVectorImpl<int64_t>* shape) {
shape->reserve(input_shape.dim_size());
auto& dims = input_shape.dim();
for (auto& d : dims) {
if (d.size() > std::numeric_limits<int64_t>::max()) {
return InvalidArgument("Shape element overflows");
}
shape->push_back(d.size());
}
return Status::OK();
}
static TensorProto ConvertToProto(const Tensor& input_tensor, static TensorProto ConvertToProto(const Tensor& input_tensor,
bool use_tensor_content = true) { bool use_tensor_content = true) {
TensorProto tensor_proto; TensorProto tensor_proto;

View File

@ -29,10 +29,6 @@ namespace tensorflow {
using stream_executor::port::StatusOr; using stream_executor::port::StatusOr;
// Converts an TensorFlow shape proto to the one used in MLIR.
Status ConvertToMlirShape(const TensorShapeProto& input_shape,
llvm::SmallVectorImpl<int64_t>* shape);
// Converts an TensorFlow tensor proto into an MLIR elements attribute. // Converts an TensorFlow tensor proto into an MLIR elements attribute.
StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor, StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
mlir::Builder* builder); mlir::Builder* builder);

View File

@ -31,7 +31,7 @@ using mlir::Builder;
using mlir::ShapedType; using mlir::ShapedType;
using mlir::Type; using mlir::Type;
Status ConvertDataType(const DataType& dtype, Builder builder, Type* type) { Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
switch (dtype) { switch (dtype) {
case DT_HALF: case DT_HALF:
*type = builder.getF16Type(); *type = builder.getF16Type();
@ -149,4 +149,38 @@ Status ConvertToDataType(Type type, DataType* dtype) {
return Status::OK(); return Status::OK();
} }
void ConvertToMlirShape(const TensorShape& input_shape,
llvm::SmallVectorImpl<int64_t>* shape) {
shape->reserve(input_shape.dims());
for (const auto& d : input_shape) {
shape->push_back(d.size);
}
}
Status ConvertToMlirShape(const TensorShapeProto& input_shape,
llvm::SmallVectorImpl<int64_t>* shape) {
shape->reserve(input_shape.dim_size());
auto& dims = input_shape.dim();
for (auto& d : dims) {
if (d.size() > std::numeric_limits<int64_t>::max()) {
return errors::InvalidArgument("Shape element overflows");
}
shape->push_back(d.size());
}
return Status::OK();
}
StatusOr<mlir::Type> ConvertToMlirTensorType(const TensorShapeProto& shape,
DataType dtype,
mlir::Builder* builder) {
mlir::Type element_type;
TF_RETURN_IF_ERROR(ConvertDataType(dtype, *builder, &element_type));
if (shape.unknown_rank()) {
return mlir::UnrankedTensorType::get(element_type);
}
llvm::SmallVector<int64_t, 4> shape_dims;
TF_RETURN_IF_ERROR(ConvertToMlirShape(shape, &shape_dims));
return mlir::RankedTensorType::get(shape_dims, element_type);
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,14 +18,17 @@ limitations under the License.
#include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow { namespace tensorflow {
using stream_executor::port::StatusOr;
// Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type. // Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type.
Status ConvertDataType(const DataType& dtype, mlir::Builder builder, Status ConvertDataType(DataType dtype, mlir::Builder builder, mlir::Type* type);
mlir::Type* type);
// Converts a scalar MLIR type to a TensorFlow Datatype. // Converts a scalar MLIR type to a TensorFlow Datatype.
Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype); Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype);
@ -34,6 +37,19 @@ Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype);
// is converted directly. If it is a shaped type, the element type is converted. // is converted directly. If it is a shaped type, the element type is converted.
Status ConvertToDataType(mlir::Type type, DataType* dtype); Status ConvertToDataType(mlir::Type type, DataType* dtype);
// Converts an TensorFlow shape to the one used in MLIR.
void ConvertToMlirShape(const TensorShape& input_shape,
llvm::SmallVectorImpl<int64_t>* shape);
// Converts an TensorFlow shape proto to the one used in MLIR.
Status ConvertToMlirShape(const TensorShapeProto& input_shape,
llvm::SmallVectorImpl<int64_t>* shape);
// Given a tensor shape and dtype, get the corresponding MLIR tensor type.
StatusOr<mlir::Type> ConvertToMlirTensorType(const TensorShapeProto& shape,
DataType dtype,
mlir::Builder* builder);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_ #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_

View File

@ -0,0 +1,64 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace {
std::string ConvertToMlirString(const std::vector<int64_t>& dims,
bool unknown_rank, DataType dtype) {
TensorShapeProto shape;
shape.set_unknown_rank(unknown_rank);
for (int64_t dim : dims) {
shape.add_dim()->set_size(dim);
}
mlir::MLIRContext context;
mlir::Builder b(&context);
auto status_or = ConvertToMlirTensorType(shape, dtype, &b);
std::string buf;
llvm::raw_string_ostream os(buf);
status_or.ValueOrDie().print(os);
return os.str();
}
TEST(MlirConvertType, ConvertToMlirTensorType) {
// Simple case of static shapes.
EXPECT_EQ("tensor<4x8x16xi32>",
ConvertToMlirString({4, 8, 16}, /*unknown_rank=*/false,
DataType::DT_INT32));
// Partially known shapes.
EXPECT_EQ("tensor<?x27x?xbf16>",
ConvertToMlirString({-1, 27, -1}, /*unknown_rank=*/false,
DataType::DT_BFLOAT16));
// Unranked shapes.
EXPECT_EQ("tensor<*xf32>",
ConvertToMlirString({}, /*unknown_rank=*/true, DataType::DT_FLOAT));
}
} // namespace
} // namespace tensorflow