Add ConvertTypeToTensorShape method.
PiperOrigin-RevId: 266999034
This commit is contained in:
parent
53ed37363c
commit
b3f1b34934
@ -386,6 +386,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "convert_tensor_test",
|
||||
size = "small",
|
||||
srcs = ["utils/convert_tensor_test.cc"],
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mangling_util",
|
||||
srcs = ["utils/mangling_util.cc"],
|
||||
|
@ -147,6 +147,23 @@ void ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
|
||||
}
|
||||
}
|
||||
|
||||
PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) {
|
||||
if (type.isa<mlir::UnrankedTensorType>()) {
|
||||
// An empty PartialTensorShape indicates an unranked tensor.
|
||||
return PartialTensorShape();
|
||||
}
|
||||
|
||||
if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
|
||||
TensorShapeProto tensor_shape_proto;
|
||||
ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto);
|
||||
return PartialTensorShape(tensor_shape_proto);
|
||||
}
|
||||
|
||||
// If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
|
||||
// Empty TensorShape indicates a scalar.
|
||||
return TensorShape();
|
||||
}
|
||||
|
||||
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
||||
Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
|
@ -45,6 +45,9 @@ StatusOr<mlir::ElementsAttr> ConvertTensor(const Tensor& input_tensor,
|
||||
void ConvertToTensorShapeProto(llvm::ArrayRef<int64_t> shape,
|
||||
TensorShapeProto* output_shape);
|
||||
|
||||
// Converts an MLIR type with static tensor shape to an TensorFlow tensor shape.
|
||||
PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type);
|
||||
|
||||
// Converts an MLIR elements attribute to an TensorFlow tensor proto.
|
||||
Status ConvertToTensorProto(mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor);
|
||||
|
@ -0,0 +1,65 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder b(&context);
|
||||
|
||||
PartialTensorShape output_shape =
|
||||
ConvertTypeToTensorShape(b.getTensorType(b.getF32Type()));
|
||||
EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape()));
|
||||
}
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder b(&context);
|
||||
|
||||
PartialTensorShape output_shape =
|
||||
ConvertTypeToTensorShape(b.getTensorType({-1, 2, 3}, b.getF32Type()));
|
||||
EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape({-1, 2, 3})));
|
||||
}
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder b(&context);
|
||||
|
||||
PartialTensorShape output_shape =
|
||||
ConvertTypeToTensorShape(b.getTensorType({1, 2, 3}, b.getF32Type()));
|
||||
EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape({1, 2, 3})));
|
||||
}
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, ScalarTensorType) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder b(&context);
|
||||
|
||||
PartialTensorShape output_shape = ConvertTypeToTensorShape(b.getF32Type());
|
||||
EXPECT_TRUE(output_shape.IsIdenticalTo(TensorShape()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user