tf_saved_model: Add explicit tracking of variable partial types.
Variables in TF can change type, within the bounds of the partial type that they are given at creation time. This CL adds another field to tf_saved_model.global_tensor to explicitly track that partial type. For example, a variable can be created with partial shape `tensor<*xf32>`, and then it can be assigned any compatible type (i.e. any f32 tensor) at any point in the program. The same applies for `tensor<?x27x?xf32>`, but the types that are legal to assign to it are now more restricted. This CL also refactors some of the utilities a little bit. PiperOrigin-RevId: 281325922 Change-Id: Ifa04f827670681839574432b9d9d0ba2c0ee5f79
This commit is contained in:
parent
6c092464ff
commit
9182431be6
|
@ -493,6 +493,22 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "convert_type_test",
|
||||
size = "small",
|
||||
srcs = ["utils/convert_type_test.cc"],
|
||||
deps = [
|
||||
":convert_type",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convert_tensor",
|
||||
srcs = ["utils/convert_tensor.cc"],
|
||||
|
|
|
@ -26,7 +26,9 @@ limitations under the License.
|
|||
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
|
@ -48,6 +50,22 @@ static bool IsStrArrayAttr(Attribute attr) {
|
|||
// TensorFlowSavedModelDialect Op's
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) {
|
||||
if (!t1.isa<TensorType>() || !t2.isa<TensorType>()) {
|
||||
return failure();
|
||||
}
|
||||
return verifyCompatibleShape(t1.cast<TensorType>(), t2.cast<TensorType>());
|
||||
}
|
||||
|
||||
static LogicalResult Verify(GlobalTensorOp global_tensor) {
|
||||
if (failed(VerifyTensorTypesCompatible(
|
||||
global_tensor.type(), global_tensor.value().Attribute::getType()))) {
|
||||
return global_tensor.emitError() << "'type' and 'value' attributes should "
|
||||
"have compatible tensor types";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
|
||||
|
||||
|
|
|
@ -111,13 +111,23 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
|
|||
|
||||
The `value` attribute contains the tensor's value (or initial value, in the
|
||||
case it is mutable).
|
||||
|
||||
The `type` attribute contains the tensor's type, which for the case of
|
||||
mutable tensors might be more general than just the fixed static shape of
|
||||
the `value` attribute. For example, a global tensor might be unranked such
|
||||
as `tensor<*xf32>`, or a more complex shape such as `tensor<4x?x27xf32>`.
|
||||
The shape of `value` must be compatible with the shape of `type` in the
|
||||
sense of `tf.TensorShape` compatibility. And the element types must match.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
StrAttr:$sym_name,
|
||||
ElementsAttr:$value,
|
||||
TypeAttr:$type,
|
||||
UnitAttr:$is_mutable
|
||||
);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
#endif // SAVED_MODEL_DIALECT
|
||||
|
|
|
@ -31,8 +31,8 @@ class TestModule(tf.Module):
|
|||
self.v42 = tf.Variable(42.0)
|
||||
self.c43 = tf.constant(43.0)
|
||||
|
||||
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], value = dense<4.200000e+01> : tensor<f32>} : () -> ()
|
||||
# CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], value = dense<4.300000e+01> : tensor<f32>} : () -> ()
|
||||
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], type = tensor<f32>, value = dense<4.200000e+01> : tensor<f32>} : () -> ()
|
||||
# CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], type = tensor<f32>, value = dense<4.300000e+01> : tensor<f32>} : () -> ()
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @[[VAR]]},
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# RUN: %p/partially_shaped_variables | FileCheck %s
|
||||
|
||||
# pylint: disable=missing-docstring,line-too-long
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow.compat.v2 as tf
|
||||
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common
|
||||
|
||||
|
||||
class TestModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(TestModule, self).__init__()
|
||||
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v0"], type = tensor<*xf32>, value = dense<0.000000e+00> : tensor<1xf32>} : () -> ()
|
||||
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v1"], type = tensor<?xf32>, value = dense<0.000000e+00> : tensor<1xf32>} : () -> ()
|
||||
self.v0 = tf.Variable([0.], shape=tf.TensorShape(None))
|
||||
self.v1 = tf.Variable([0.], shape=[None])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
common.do_test(TestModule, exported_names=[])
|
|
@ -5,7 +5,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
// Test case: Simple case of inlining.
|
||||
|
||||
// CHECK-NOT: tf_saved_model.global_tensor
|
||||
"tf_saved_model.global_tensor"() { sym_name = "c", value = dense<1.0> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f()
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
|
||||
|
@ -23,7 +23,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
// Test case: Do not inline mutable global tensors.
|
||||
|
||||
// CHECK: tf_saved_model.global_tensor
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<1.0> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
||||
|
|
|
@ -7,6 +7,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
"tf_saved_model.global_tensor"() {
|
||||
tf_saved_model.exported_names = ["some_const"],
|
||||
sym_name = "some_constant",
|
||||
type = tensor<1x64xf32>,
|
||||
value = dense<42.0> : tensor<1x64xf32>
|
||||
} : () -> ()
|
||||
|
||||
|
@ -16,6 +17,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
is_mutable,
|
||||
tf_saved_model.exported_names = ["some_var", "some.other.name"],
|
||||
sym_name = "some_variable",
|
||||
type = tensor<?x64xf32>,
|
||||
value = dense<42.0> : tensor<1x64xf32>
|
||||
} : () -> ()
|
||||
|
||||
|
|
|
@ -133,7 +133,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.global_tensor"() { sym_name = "some_constant", value = dense<42.0> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { sym_name = "some_constant", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
|
||||
|
||||
// expected-error@+1 {{all 'tf_saved_model.index_path' arg attributes should precede all 'tf_saved_model.bound_input' arg attributes}}
|
||||
func @f(
|
||||
|
@ -205,3 +205,10 @@ module attributes {tf_saved_model.semantics} {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
// expected-error@+1 {{'type' and 'value' attributes should have compatible tensor types}}
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v0", type = tensor<3xf32>, value = dense<42.0> : tensor<9xf32> } : () -> ()
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-NOT: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<42.> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
|
@ -33,7 +33,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", value = dense<42.> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
|
@ -55,7 +55,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], value = dense<42.> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
|
@ -88,7 +88,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
|
||||
// Test case: Check that an immutable bound input isn't modified.
|
||||
|
||||
"tf_saved_model.global_tensor"() { sym_name = "c", value = dense<42.> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
|
||||
func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
|
@ -109,13 +109,13 @@ module attributes {tf_saved_model.semantics} {
|
|||
// Test case: Check that an exported global tensor that isn't bound to an
|
||||
// argument is not erased.
|
||||
|
||||
"tf_saved_model.global_tensor"() { sym_name = "exported_unbound", tf_saved_model.exported_names = ["exported_unbound"], value = dense<42.> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { sym_name = "exported_unbound", tf_saved_model.exported_names = ["exported_unbound"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
// CHECK: sym_name = "exported_unbound"
|
||||
|
||||
// Test case: Check that a global tensor that isn't even bound to an argument
|
||||
// is erased.
|
||||
|
||||
"tf_saved_model.global_tensor"() { sym_name = "unexported_unbound", value = dense<42.> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { sym_name = "unexported_unbound", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
// CHECK-NOT: sym_name = "unexported_unbound"
|
||||
|
||||
}
|
||||
|
@ -131,7 +131,7 @@ module attributes {tf_saved_model.semantics} {
|
|||
// We erase the argument that this global tensor is bound to, so we delete
|
||||
// the global tensor too.
|
||||
// CHECK-NOT: tf_saved_model.global_tensor
|
||||
"tf_saved_model.global_tensor"() { sym_name = "c", value = dense<42.> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f()
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
|
||||
|
|
|
@ -84,6 +84,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
||||
return {ref.data(), ref.size()};
|
||||
|
@ -2341,11 +2342,16 @@ Status CreateSavedModelIR(
|
|||
TF_ASSIGN_OR_RETURN(
|
||||
Tensor value, ReadVariableFromSession(saved_model, variable.name()));
|
||||
TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
|
||||
|
||||
// A variable can have a partially known type, such as tensor<?x27x?xf32>,
|
||||
// even if the initializer is a specific static shape.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(),
|
||||
&builder));
|
||||
auto op = builder.create<mlir::tf_saved_model::GlobalTensorOp>(
|
||||
builder.getUnknownLoc(),
|
||||
builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
|
||||
value_attr,
|
||||
/*type=*/mlir::TypeAttr::get(type),
|
||||
/*is_mutable=*/builder.getUnitAttr());
|
||||
op.setAttr(
|
||||
"tf_saved_model.exported_names",
|
||||
|
@ -2360,6 +2366,7 @@ Status CreateSavedModelIR(
|
|||
builder.getUnknownLoc(),
|
||||
builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
|
||||
value_attr,
|
||||
/*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
|
||||
/*is_mutable=*/nullptr);
|
||||
op.setAttr(
|
||||
"tf_saved_model.exported_names",
|
||||
|
|
|
@ -52,27 +52,6 @@ using mlir::ShapedType;
|
|||
using mlir::Type;
|
||||
using tensorflow::errors::InvalidArgument;
|
||||
|
||||
void ConvertToMlirShape(const TensorShape& input_shape,
|
||||
llvm::SmallVectorImpl<int64_t>* shape) {
|
||||
shape->reserve(input_shape.dims());
|
||||
for (const auto& d : input_shape) {
|
||||
shape->push_back(d.size);
|
||||
}
|
||||
}
|
||||
|
||||
Status ConvertToMlirShape(const TensorShapeProto& input_shape,
|
||||
llvm::SmallVectorImpl<int64_t>* shape) {
|
||||
shape->reserve(input_shape.dim_size());
|
||||
auto& dims = input_shape.dim();
|
||||
for (auto& d : dims) {
|
||||
if (d.size() > std::numeric_limits<int64_t>::max()) {
|
||||
return InvalidArgument("Shape element overflows");
|
||||
}
|
||||
shape->push_back(d.size());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static TensorProto ConvertToProto(const Tensor& input_tensor,
|
||||
bool use_tensor_content = true) {
|
||||
TensorProto tensor_proto;
|
||||
|
|
|
@ -29,10 +29,6 @@ namespace tensorflow {
|
|||
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// Converts an TensorFlow shape proto to the one used in MLIR.
|
||||
Status ConvertToMlirShape(const TensorShapeProto& input_shape,
|
||||
llvm::SmallVectorImpl<int64_t>* shape);
|
||||
|
||||
// Converts an TensorFlow tensor proto into an MLIR elements attribute.
|
||||
StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
|
||||
mlir::Builder* builder);
|
||||
|
|
|
@ -31,7 +31,7 @@ using mlir::Builder;
|
|||
using mlir::ShapedType;
|
||||
using mlir::Type;
|
||||
|
||||
Status ConvertDataType(const DataType& dtype, Builder builder, Type* type) {
|
||||
Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
|
||||
switch (dtype) {
|
||||
case DT_HALF:
|
||||
*type = builder.getF16Type();
|
||||
|
@ -149,4 +149,38 @@ Status ConvertToDataType(Type type, DataType* dtype) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void ConvertToMlirShape(const TensorShape& input_shape,
|
||||
llvm::SmallVectorImpl<int64_t>* shape) {
|
||||
shape->reserve(input_shape.dims());
|
||||
for (const auto& d : input_shape) {
|
||||
shape->push_back(d.size);
|
||||
}
|
||||
}
|
||||
|
||||
Status ConvertToMlirShape(const TensorShapeProto& input_shape,
|
||||
llvm::SmallVectorImpl<int64_t>* shape) {
|
||||
shape->reserve(input_shape.dim_size());
|
||||
auto& dims = input_shape.dim();
|
||||
for (auto& d : dims) {
|
||||
if (d.size() > std::numeric_limits<int64_t>::max()) {
|
||||
return errors::InvalidArgument("Shape element overflows");
|
||||
}
|
||||
shape->push_back(d.size());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<mlir::Type> ConvertToMlirTensorType(const TensorShapeProto& shape,
|
||||
DataType dtype,
|
||||
mlir::Builder* builder) {
|
||||
mlir::Type element_type;
|
||||
TF_RETURN_IF_ERROR(ConvertDataType(dtype, *builder, &element_type));
|
||||
if (shape.unknown_rank()) {
|
||||
return mlir::UnrankedTensorType::get(element_type);
|
||||
}
|
||||
llvm::SmallVector<int64_t, 4> shape_dims;
|
||||
TF_RETURN_IF_ERROR(ConvertToMlirShape(shape, &shape_dims));
|
||||
return mlir::RankedTensorType::get(shape_dims, element_type);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -18,14 +18,17 @@ limitations under the License.
|
|||
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type.
|
||||
Status ConvertDataType(const DataType& dtype, mlir::Builder builder,
|
||||
mlir::Type* type);
|
||||
Status ConvertDataType(DataType dtype, mlir::Builder builder, mlir::Type* type);
|
||||
|
||||
// Converts a scalar MLIR type to a TensorFlow Datatype.
|
||||
Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype);
|
||||
|
@ -34,6 +37,19 @@ Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype);
|
|||
// is converted directly. If it is a shaped type, the element type is converted.
|
||||
Status ConvertToDataType(mlir::Type type, DataType* dtype);
|
||||
|
||||
// Converts an TensorFlow shape to the one used in MLIR.
|
||||
void ConvertToMlirShape(const TensorShape& input_shape,
|
||||
llvm::SmallVectorImpl<int64_t>* shape);
|
||||
|
||||
// Converts an TensorFlow shape proto to the one used in MLIR.
|
||||
Status ConvertToMlirShape(const TensorShapeProto& input_shape,
|
||||
llvm::SmallVectorImpl<int64_t>* shape);
|
||||
|
||||
// Given a tensor shape and dtype, get the corresponding MLIR tensor type.
|
||||
StatusOr<mlir::Type> ConvertToMlirTensorType(const TensorShapeProto& shape,
|
||||
DataType dtype,
|
||||
mlir::Builder* builder);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
std::string ConvertToMlirString(const std::vector<int64_t>& dims,
|
||||
bool unknown_rank, DataType dtype) {
|
||||
TensorShapeProto shape;
|
||||
shape.set_unknown_rank(unknown_rank);
|
||||
for (int64_t dim : dims) {
|
||||
shape.add_dim()->set_size(dim);
|
||||
}
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder b(&context);
|
||||
auto status_or = ConvertToMlirTensorType(shape, dtype, &b);
|
||||
std::string buf;
|
||||
llvm::raw_string_ostream os(buf);
|
||||
status_or.ValueOrDie().print(os);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
TEST(MlirConvertType, ConvertToMlirTensorType) {
|
||||
// Simple case of static shapes.
|
||||
EXPECT_EQ("tensor<4x8x16xi32>",
|
||||
ConvertToMlirString({4, 8, 16}, /*unknown_rank=*/false,
|
||||
DataType::DT_INT32));
|
||||
|
||||
// Partially known shapes.
|
||||
EXPECT_EQ("tensor<?x27x?xbf16>",
|
||||
ConvertToMlirString({-1, 27, -1}, /*unknown_rank=*/false,
|
||||
DataType::DT_BFLOAT16));
|
||||
|
||||
// Unranked shapes.
|
||||
EXPECT_EQ("tensor<*xf32>",
|
||||
ConvertToMlirString({}, /*unknown_rank=*/true, DataType::DT_FLOAT));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue