Add ConvertTypeToTensorShape method.
PiperOrigin-RevId: 266999034
This commit is contained in:
parent
53ed37363c
commit
b3f1b34934
@ -386,6 +386,21 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "convert_tensor_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["utils/convert_tensor_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":convert_tensor",
|
||||||
|
"//tensorflow/compiler/xla:test",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/stream_executor/lib",
|
||||||
|
"@local_config_mlir//:IR",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "mangling_util",
|
name = "mangling_util",
|
||||||
srcs = ["utils/mangling_util.cc"],
|
srcs = ["utils/mangling_util.cc"],
|
||||||
|
@ -147,6 +147,23 @@ void ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) {
|
||||||
|
if (type.isa<mlir::UnrankedTensorType>()) {
|
||||||
|
// An empty PartialTensorShape indicates an unranked tensor.
|
||||||
|
return PartialTensorShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
|
||||||
|
TensorShapeProto tensor_shape_proto;
|
||||||
|
ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto);
|
||||||
|
return PartialTensorShape(tensor_shape_proto);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
|
||||||
|
// Empty TensorShape indicates a scalar.
|
||||||
|
return TensorShape();
|
||||||
|
}
|
||||||
|
|
||||||
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
||||||
Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
|
Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
|
||||||
TensorProto* output_tensor) {
|
TensorProto* output_tensor) {
|
||||||
|
@ -45,6 +45,9 @@ StatusOr<mlir::ElementsAttr> ConvertTensor(const Tensor& input_tensor,
|
|||||||
void ConvertToTensorShapeProto(llvm::ArrayRef<int64_t> shape,
|
void ConvertToTensorShapeProto(llvm::ArrayRef<int64_t> shape,
|
||||||
TensorShapeProto* output_shape);
|
TensorShapeProto* output_shape);
|
||||||
|
|
||||||
|
// Converts an MLIR type with static tensor shape to an TensorFlow tensor shape.
|
||||||
|
PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type);
|
||||||
|
|
||||||
// Converts an MLIR elements attribute to an TensorFlow tensor proto.
|
// Converts an MLIR elements attribute to an TensorFlow tensor proto.
|
||||||
Status ConvertToTensorProto(mlir::ElementsAttr attr,
|
Status ConvertToTensorProto(mlir::ElementsAttr attr,
|
||||||
TensorProto* output_tensor);
|
TensorProto* output_tensor);
|
||||||
|
@ -0,0 +1,65 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::Builder b(&context);
|
||||||
|
|
||||||
|
PartialTensorShape output_shape =
|
||||||
|
ConvertTypeToTensorShape(b.getTensorType(b.getF32Type()));
|
||||||
|
EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::Builder b(&context);
|
||||||
|
|
||||||
|
PartialTensorShape output_shape =
|
||||||
|
ConvertTypeToTensorShape(b.getTensorType({-1, 2, 3}, b.getF32Type()));
|
||||||
|
EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape({-1, 2, 3})));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) {
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::Builder b(&context);
|
||||||
|
|
||||||
|
PartialTensorShape output_shape =
|
||||||
|
ConvertTypeToTensorShape(b.getTensorType({1, 2, 3}, b.getF32Type()));
|
||||||
|
EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape({1, 2, 3})));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConvertTypeToTensorTypeTest, ScalarTensorType) {
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::Builder b(&context);
|
||||||
|
|
||||||
|
PartialTensorShape output_shape = ConvertTypeToTensorShape(b.getF32Type());
|
||||||
|
EXPECT_TRUE(output_shape.IsIdenticalTo(TensorShape()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user