Change Shape parsing from regexp matcher to parser.
Previously in the HLO parser/lexer, shapes were tokens identified using a complicated regular expression. This made augmenting the textual form of shapes difficult, such as would be necessary for dynamic shapes or tiling. To avoid ambiguity and other problems, a couple of changes were made to the HLO textual form, along with some related cleanup:

(1) Do not redundantly print the shape inside the constant HLO instruction's "operand" field. Previously, constant instructions were printed like:

  S32[2,2] constant(S32[2,2] {{1,2},{3,4}})

Now this is printed as:

  S32[2,2] constant({{1,2},{3,4}})

This avoids an ambiguity where the values of the literal can be misinterpreted as a layout. Also, the shape was printed inconsistently: only when the rank was greater than one.

(2) Remove ShapeUtil::ParseShapeString and replace it with a ParseShape function in the HLO parser.

(3) Merge hlo_token.h into hlo_lexer.h. It is only used by the lexer and parser, which include that file, and the merge avoids potential confusion with the token HLO type.

(4) Fix b/112302613 by removing the unused Shape field in the sharding attribute of HLO text.

(5) As part of this change, primitive element types are now keywords, which simplifies parsing. The fallout is that a number of values in HLO text named "token" had to be renamed. Also, change the HLO name sanitizer to avoid these primitive type keywords.

PiperOrigin-RevId: 225546437
Parent: 13187e1566
Commit: b55e7a9a82
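To illustrate the new parsing entry point, here is a minimal, hedged sketch of how a caller might use the parser-based ParseShape function that replaces ShapeUtil::ParseShapeString. The call site below is not taken from this commit; the header path matches the new //tensorflow/compiler/xla/service:hlo_parser dependency added by the diff, but the exact signature and error handling shown are assumptions.

#include "tensorflow/compiler/xla/service/hlo_parser.h"

void ParseShapeExample() {
  // Parse the textual form "f32[2,3]{1,0}" into a Shape. ParseShape is assumed
  // to return StatusOr<Shape>, so the caller decides how malformed input is
  // handled instead of relying on the old regexp-based shape token matching.
  xla::StatusOr<xla::Shape> shape_or = xla::ParseShape("f32[2,3]{1,0}");
  if (shape_or.ok()) {
    const xla::Shape& shape = shape_or.ValueOrDie();
    // shape.element_type() == xla::F32, dimensions {2, 3}, layout {1, 0}.
  }
}

Because primitive element types such as f32 and pred are now lexer keywords, a non-tuple shape begins with a kPrimitiveType token, and the new StringToPrimitiveType helper accepts only the lower-case names (the added primitive_util_test expects "F32" and "Pred" to be rejected).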
@@ -91,7 +91,7 @@ TEST(ConvertGraphDefToXla, Sum) {
client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
TF_EXPECT_OK(result_or.status());
xla::Literal result = std::move(result_or.ValueOrDie());
EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());
EXPECT_EQ("(\ns32[] 42\n)", result.ToString());

config.mutable_feed(0)->mutable_id()->set_output_index(
123); /* invalid output_index */

@@ -292,6 +292,22 @@ tf_cc_test(
],
)

tf_cc_test(
name = "primitive_util_test",
srcs = ["primitive_util_test.cc"],
deps = [
":shape_util",
":status_macros",
":test",
":test_helpers",
":types",
":util",
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
)

tf_cc_test(
name = "layout_util_test",
srcs = ["layout_util_test.cc"],

@@ -593,6 +609,7 @@ cc_library(
":types",
":util",
":xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/memory",
@@ -1028,20 +1028,21 @@ string ShapeToString(bool print_layout, const Shape& shape) {
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
bool print_layout, std::vector<string>* pieces);
bool print_shape, bool print_layout,
std::vector<string>* pieces);

void TupleToStringHelper(const LiteralBase& literal,
const ShapeIndex& shape_index, bool print_layout,
std::vector<string>* pieces) {
const ShapeIndex& shape_index, bool print_shape,
bool print_layout, std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
pieces->push_back(ShapeToString(print_layout, subshape));
pieces->push_back(" (\n");
pieces->push_back("(\n");
std::vector<string> tuple_pieces;
for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
ShapeIndex element_index = shape_index;
element_index.push_back(i);
std::vector<string> element_pieces;
ToStringHelper(literal, element_index, print_layout, &element_pieces);
ToStringHelper(literal, element_index, print_shape, print_layout,
&element_pieces);
tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
}
pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));

@@ -1049,9 +1050,11 @@ void TupleToStringHelper(const LiteralBase& literal,
}

void SparseArrayToStringHelper(const LiteralBase& literal,
const Shape& subshape, bool print_layout,
std::vector<string>* pieces) {
const Shape& subshape, bool print_shape,
bool print_layout, std::vector<string>* pieces) {
if (print_shape) {
pieces->push_back(ShapeToString(print_layout, subshape));
}
pieces->push_back("{");
int64 rank = ShapeUtil::Rank(subshape);
int64 num_elements = literal.sparse_element_count();

@@ -1073,8 +1076,8 @@ void SparseArrayToStringHelper(const LiteralBase& literal,
}

void DenseArrayToStringHelper(const LiteralBase& literal,
const ShapeIndex& shape_index, bool print_layout,
std::vector<string>* pieces) {
const ShapeIndex& shape_index, bool print_shape,
bool print_layout, std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
int64 rank = ShapeUtil::Rank(subshape);

@@ -1135,7 +1138,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal,
}
};

if (rank > 1) {
if (print_shape) {
pieces->push_back(ShapeToString(print_layout, subshape));
pieces->push_back(" ");
}

@@ -1146,19 +1149,23 @@ void DenseArrayToStringHelper(const LiteralBase& literal,
}

void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
bool print_layout, std::vector<string>* pieces) {
bool print_shape, bool print_layout,
std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
CHECK(LayoutUtil::HasLayout(literal.shape()));
CHECK(LayoutUtil::HasLayout(subshape));
if (ShapeUtil::IsTuple(subshape)) {
TupleToStringHelper(literal, shape_index, print_layout, pieces);
TupleToStringHelper(literal, shape_index, print_shape, print_layout,
pieces);
} else if (ShapeUtil::IsToken(subshape)) {
pieces->push_back("token");
} else if (LayoutUtil::IsSparseArray(subshape)) {
SparseArrayToStringHelper(literal, subshape, print_layout, pieces);
SparseArrayToStringHelper(literal, subshape, print_shape, print_layout,
pieces);
} else {
CHECK(LayoutUtil::IsDenseArray(subshape));
DenseArrayToStringHelper(literal, shape_index, print_layout, pieces);
DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
pieces);
}
}

@@ -1169,10 +1176,27 @@ int64 LiteralBase::sparse_element_count() const {
return sparse_indices()->index_count();
}

string LiteralBase::ToString(bool print_layout) const {
string LiteralBase::ToString() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, print_layout, &pieces);
ToStringHelper(*this, {}, /*print_shape=*/true,
/*print_layout=*/false, &pieces);
return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringWithoutShape() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, /*print_shape=*/false,
/*print_layout=*/false, &pieces);
return absl::StrJoin(pieces, "");
}

string LiteralBase::ToStringWithLayout() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, /*print_shape=*/true,
/*print_layout=*/true, &pieces);
return absl::StrJoin(pieces, "");
}
@@ -92,9 +92,20 @@ class LiteralBase {
// array.
string GetR1U8AsString() const;

// Returns a string representation of the literal value.
// Warning: this function can take minutes for multi-million element Literals.
string ToString(bool print_layout = false) const;
// Returns a string representation of the literal value. The Shape of the
// literal is a prefix of the literal value in the string.

// Warning: this function can take minutes for multi-million
// element Literals.
string ToString() const;

// Returns a string representation of the literal value which does *not*
// include the shape string.
string ToStringWithoutShape() const;

// Returns a string representation of the literal value which includes the
// shape string with its layout.does *not* include the shape string.
string ToStringWithLayout() const;

// Gets an element in the literal at the given index. The multi_index is
// CHECKed against the dimension sizes.
@@ -98,42 +98,42 @@ class LiteralUtilTest : public ::testing::Test {

TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto true_lit = LiteralUtil::CreateR0<bool>(true);
EXPECT_EQ("true", true_lit.ToString());
EXPECT_EQ("pred[] true", true_lit.ToString());

auto false_lit = LiteralUtil::CreateR0<bool>(false);
EXPECT_EQ("false", false_lit.ToString());
EXPECT_EQ("pred[] false", false_lit.ToString());

auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
EXPECT_EQ("42", u32_lit.ToString());
EXPECT_EQ("u32[] 42", u32_lit.ToString());

auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
EXPECT_EQ("-999", s32_lit.ToString());
EXPECT_EQ("s32[] -999", s32_lit.ToString());

auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
EXPECT_EQ("3.14", f32_lit.ToString());
EXPECT_EQ("f32[] 3.14", f32_lit.ToString());

auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
EXPECT_EQ("0.5", f16_lit.ToString());
EXPECT_EQ("f16[] 0.5", f16_lit.ToString());

auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString());
EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString());

auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
EXPECT_EQ("0.5", bf16_lit.ToString());
EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString());

// 3.14 will be rounded to 3.14062 in bfloat16 format.
auto bf16_lit_truncated =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
ASSERT_EQ("3.14062", bf16_lit_truncated.ToString());
ASSERT_EQ("bf16[] 3.14062", bf16_lit_truncated.ToString());

auto bf16_lit_truncated2 =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
EXPECT_EQ("9", bf16_lit_truncated2.ToString());
EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString());
}

TEST_F(LiteralUtilTest, LiteralVectorToString) {
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
EXPECT_EQ("{1, 0, 1}", pred_vec.ToString());
EXPECT_EQ("pred[3] {1, 0, 1}", pred_vec.ToString());
}

TEST_F(LiteralUtilTest, R2ToString) {

@@ -210,8 +210,8 @@ TEST_F(LiteralUtilTest, TupleToString) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
const string expected = R"((f32[], f32[2,2]) (
1,
const string expected = R"((
f32[] 1,
f32[2,2] {
{ 1, 2 },
{ 3, 4 }

@@ -1890,7 +1890,7 @@ TEST_F(LiteralUtilTest, SortSparseElements) {
literal.AppendSparseElement<float>({3, 4, 5}, 3.0);
literal.AppendSparseElement<float>({1, 2, 3}, 1.0);
literal.SortSparseElements();
EXPECT_EQ(literal.ToString(false),
EXPECT_EQ(literal.ToString(),
"f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
}
@@ -15,6 +15,9 @@ limitations under the License.

#include "tensorflow/compiler/xla/primitive_util.h"

#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

@@ -90,5 +93,65 @@ bool IsArrayType(PrimitiveType primitive_type) {
primitive_type != OPAQUE && primitive_type != TOKEN;
}

// Class to memoize the computation of
// absl::AsciiStrToLower(PrimitiveType_Name(p))
// for all PrimitiveType values "p"
class PrimitiveTypeNameGenerator {
public:
PrimitiveTypeNameGenerator() {
for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
if (PrimitiveType_IsValid(i)) {
lowercase_name_[i] = absl::AsciiStrToLower(
PrimitiveType_Name(static_cast<PrimitiveType>(i)));
}
}
}
const string& LowercaseName(PrimitiveType t) {
return lowercase_name_[static_cast<int>(t)];
}

private:
string lowercase_name_[PrimitiveType_ARRAYSIZE];
};

const string& LowercasePrimitiveTypeName(PrimitiveType s) {
static auto* gen = new PrimitiveTypeNameGenerator();
return gen->LowercaseName(s);
}

namespace {

// Returns a map from lower-case primitive type name to primitive type.
const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() {
static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
static auto* map = new std::unordered_map<string, PrimitiveType>;
for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) {
auto value = static_cast<PrimitiveType>(i);
(*map)[LowercasePrimitiveTypeName(value)] = value;
}
}
return map;
}();
return *name_to_type;
}

} // namespace

StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) {
const auto& map = GetPrimitiveTypeStringMap();
auto found = map.find(string(name));
if (found == map.end()) {
return InvalidArgument("Invalid element type string: \"%s\".", name);
}
return found->second;
}

bool IsPrimitiveTypeName(absl::string_view name) {
const auto& map = GetPrimitiveTypeStringMap();
auto found = map.find(string(name));
return found != map.end();
}

} // namespace primitive_util
} // namespace xla
@@ -20,6 +20,9 @@ limitations under the License.

#include <type_traits>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

@@ -221,6 +224,17 @@ template <>
struct PrimitiveTypeToNative<C64> {
using type = complex64;
};

// Returns the lower-case name of the given primitive type.
const string& LowercasePrimitiveTypeName(PrimitiveType s);

// Returns the PrimitiveType matching the given name. The given name is expected
// to be lower-case.
StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name);

// Returns true if the given name is a primitive type string (lower-case).
bool IsPrimitiveTypeName(absl::string_view name);

} // namespace primitive_util
} // namespace xla
tensorflow/compiler/xla/primitive_util_test.cc (new file, 46 lines added)
@@ -0,0 +1,46 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/primitive_util.h"

#include <numeric>
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

TEST(PrimitiveUtilTest, StringToPrimitiveType) {
auto expect_ok_and_equal = [](const string& str, PrimitiveType expected) {
TF_ASSERT_OK_AND_ASSIGN(PrimitiveType actual,
primitive_util::StringToPrimitiveType(str));
EXPECT_EQ(expected, actual);
};
expect_ok_and_equal("f32", F32);
expect_ok_and_equal("tuple", TUPLE);
expect_ok_and_equal("pred", PRED);
expect_ok_and_equal("s32", S32);

EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("F32").status());
EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("Pred").status());
EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("preD").status());
}

} // namespace
} // namespace xla
@@ -1014,6 +1014,7 @@ cc_library(
srcs = ["name_uniquer.cc"],
hdrs = ["name_uniquer.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",

@@ -1785,6 +1786,7 @@ tf_cc_test(
":hlo_cse",
":hlo_dce",
":hlo_matchers",
":hlo_parser",
":hlo_pass",
":hlo_pass_pipeline",
":tuple_simplifier",

@@ -3628,7 +3630,6 @@ cc_library(
srcs = ["hlo_lexer.cc"],
hdrs = [
"hlo_lexer.h",
"hlo_token.h",
],
deps = [
"//tensorflow/compiler/xla:shape_util",
@@ -32,8 +32,8 @@ HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
%p = f32[2,2] parameter(0)
%constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
%constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}})
ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
}
)";

@@ -91,7 +91,7 @@ HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) {
%p = f32[2,2] parameter(0)
%constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
%tuple1 = (f32[2,2]) tuple(%constant.f32)
%tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2)

@@ -152,7 +152,7 @@ HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
%p = f32[2,2] parameter(0)
%constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
%tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
%get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
%get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0

@@ -174,7 +174,7 @@ HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
%p = f32[2,2] parameter(0)
%constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
%tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
%get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
%get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1

@@ -196,8 +196,8 @@ HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
%p = f32[2,2] parameter(0)
%constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 5}})
%constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
%constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}})
%tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
%get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
%get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1

@@ -226,7 +226,7 @@ HloModule foobar

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
%x = (f32[2,2], f32[2,2]) parameter(0)
%constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
%get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
%get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
%add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)

@@ -235,7 +235,7 @@ HloModule foobar
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
%constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}})
%constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
%init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}

@@ -263,7 +263,7 @@ HloModule foobar

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
%x = (f32[2,2], f32[2,2]) parameter(0)
%constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
%get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
%get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
%add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)

@@ -272,8 +272,8 @@ HloModule foobar
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
%constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}})
%constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}})
%constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}})
%constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}})
%init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}

@@ -301,8 +301,8 @@ HloModule foobar

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
%x = (f32[2,2], f32[2,2]) parameter(0)
%constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}})
%constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
%constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}})
%get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
%get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
%add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1)

@@ -311,7 +311,7 @@ HloModule foobar
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
%constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}})
%constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
%init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
@@ -112,10 +112,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
const string hlo_string = R"(
HloModule TestTaskParallel_infeed_outfeed
ENTRY InfeedOutfeed {
token = token[] after-all()
infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token)
token0 = token[] after-all()
infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token0)
infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0
ROOT outfeed0 = token[] outfeed(infeed0.data, token)
ROOT outfeed0 = token[] outfeed(infeed0.data, token0)
}
)";
@@ -31,29 +31,27 @@ HloModule RepeatedConstants
while_body {
arg_body = f32[2,3,2] parameter(0)
ROOT const = f32[2,3,2] constant(
f32[2,3,2]
{{{1, 2}, {1001, 1002}, {2001, 2002}},
{{2, 1}, {2001, 3002}, {2001, 2002}}})
}

while_cond {
arg_cond = f32[2,3,2] parameter(0)
token = token[] after-all()
infeed = (pred[], token[]) infeed(token)
token0 = token[] after-all()
infeed = (pred[], token[]) infeed(token0)
ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}

ENTRY main {
param = f32[2,3,2] parameter(0)
const_a = f32[2,3,2] constant(
f32[2,3,2]
{{{1, 2}, {1001, 1002}, {2001, 2002}},
{{2, 1}, {2001, 3002}, {2001, 2002}}})
const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body

token = token[] after-all()
out0 = token[] outfeed(f32[2,3,2] const_a, token[] token)
ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token)
token0 = token[] after-all()
out0 = token[] outfeed(f32[2,3,2] const_a, token[] token0)
ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token0)
}
)";

@@ -82,24 +80,24 @@ HloModule RepeatedConstants

while_body {
arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0)
ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} ))
ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant(({ { 1 }, { 2 } }, {2} ))
}

while_cond {
arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0)
token = token[] after-all()
infeed = (pred[], token[]) infeed(token)
token0 = token[] after-all()
infeed = (pred[], token[]) infeed(token0)
ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}

ENTRY main {
param = f32[2,3,2] parameter(0)
const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} ))
const_a = (f32[2,1]{1,0}, f32[1]{0}) constant(( { { 1 }, { 2 } }, {2} ))
const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body

token = token[] after-all()
out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token)
ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token)
token0 = token[] after-all()
out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token0)
ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token0)
}
)";
@@ -28,12 +28,11 @@ HloModule Outfeed

ENTRY main {
const_a = f32[2,3,2] constant(
f32[2,3,2]
{{{1, 2}, {1001, 1002}, {2001, 2002}},
{{2, 1}, {2001, 3002}, {2001, 2002}}})

token = token[] after-all()
outfeed = token[] outfeed(f32[2,3,2] const_a, token)
token0 = token[] after-all()
outfeed = token[] outfeed(f32[2,3,2] const_a, token0)
ROOT root = () tuple()
}
)";
@@ -599,7 +599,7 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) {
Array4D<float> constant_arr(4, 4, 2, 2);
constant_arr.FillIota(0);
string constant_str =
LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape();

const string module_str = absl::StrFormat(R"(
HloModule test
@@ -369,7 +369,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
const char* hlo_text = R"(
HloModule SortLayout
ENTRY sort {
keys = f32[3,2]{0,1} constant(f32[3,2]{0,1}{{0,1},{0,1},{0,1}})
keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}})
values = f32[2,3]{1,0} parameter(0)
transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0}
ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose),
@@ -252,7 +252,7 @@ const char* const kConstantFoldLargePad = R"(
HloModule ConstantFoldLargePad

ENTRY r {
a = f32[1,1,1] constant(f32[1,1,1]{{{7}}})
a = f32[1,1,1] constant({{{7}}})
b = f32[] constant(42)
ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63
})";
@@ -1882,8 +1882,8 @@ TEST_P(HloDataflowAnalysisTest, AddDependency) {
HloModule AddDependency
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
%p = f32[3] parameter(0)
%token = token[] after-all()
ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token)
%token0 = token[] after-all()
ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(
@@ -195,10 +195,10 @@ HloModule Module
ENTRY entry {
p0 = (f32[4]) parameter(0)
a = f32[4] get-tuple-element(p0), index=0
token = token[] after-all()
b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0}
token0 = token[] after-all()
b = (f32[4], u32[], token[]) send(a, token0), channel_id=1, sharding={maximal device=0}
c = token[] send-done(b), channel_id=1, sharding={maximal device=0}
d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0}
d = (f32[4], u32[], token[]) recv(token0), channel_id=2, sharding={maximal device=0}
e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0}
e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0}
f = f32[4] add(a, e_element)

@@ -235,12 +235,12 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) {
HloModule Module

ENTRY entry {
token = token[] after-all(), sharding={maximal device=-1}
a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1}
token0 = token[] after-all(), sharding={maximal device=-1}
a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=-1}
b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1}
b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1}
c = f32[4] add(b_element, b_element), sharding={maximal device=-1}
d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1}
d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=-1}
ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1}
}
)";

@@ -259,12 +259,12 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) {
HloModule Module

ENTRY entry {
token = token[] after-all(), sharding={maximal device=0}
a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0}
token0 = token[] after-all(), sharding={maximal device=0}
a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=0}
b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0}
b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0}
c = f32[4] add(b_element, b_element)
d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0}
d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=0}
ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0}
}
)";

@@ -344,8 +344,8 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) {
HloModule Module

ENTRY entry {
token = token[] after-all()
infeed = ((f32[4], f32[4]), token[]) infeed(token),
token0 = token[] after-all()
infeed = ((f32[4], f32[4]), token[]) infeed(token0),
sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}}
infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0,
sharding={{maximal device=1}, {maximal device=0}}
@@ -57,10 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) {
const string& hlo_string = R"(
HloModule InfeedOutfeed
ENTRY RoundTrip16MiBR1.v2 {
token = token[] after-all()
infeed = (bf16[4]{0}, token[]) infeed(token)
token0 = token[] after-all()
infeed = (bf16[4]{0}, token[]) infeed(token0)
ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0
outfeed = token[] outfeed(infeed.data, token)
outfeed = token[] outfeed(infeed.data, token0)
}
)";
auto module = CreateModuleFromHloString(hlo_string);

@@ -96,13 +96,13 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) {
const string& hlo_string = R"(
HloModule BatchNormGrad
ENTRY BatchNormGrad.v6 {
constant.4 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/
constant.4 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/
{ /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } }, { /*i0=1*/ { /*i1=0*/ {0},
{0} }, { /*i1=1*/ {0}, {0} } } })
constant.5 = bf16[2]{0} constant({1, 1})
constant.6 = bf16[2]{0} constant({0, 0})
constant.7 = bf16[2]{0} constant({1, 1})
constant.8 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/
constant.8 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/
{ /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} } }, { /*i0=1*/ { /*i1=0*/
{5}, {6} }, { /*i1=1*/ {7}, {8} } } })
ROOT batch-norm-grad = (bf16[2,2,2,1]{3,2,1,0}, bf16[2]{0}, bf16[2]{0})
@@ -905,7 +905,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
// lines. Compact this into one line by stripping out white space.
string tmp = literal().ToString();
string tmp = literal().ToStringWithoutShape();
std::replace(tmp.begin(), tmp.end(), '\n', ' ');
std::vector<string> v = absl::StrSplit(tmp, ' ');
bool first = true;
@@ -19,6 +19,7 @@ limitations under the License.

#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

@@ -82,9 +83,23 @@ tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
return tensorflow::RegexpStringPiece(begin, end - begin);
}

TokKind HloLexer::LookAhead() {
if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) {
return GetKind();
}

const char* old_current_ptr = current_ptr_;
TokenState old_token_state = token_state_;
Lex();
TokKind kind = GetKind();
token_state_ = old_token_state;
current_ptr_ = old_current_ptr;
return kind;
}

TokKind HloLexer::LexToken() {
while (true) {
token_start_ = current_ptr_;
token_state_.token_start = current_ptr_;

int current_char = GetNextChar();
switch (current_char) {

@@ -206,43 +221,37 @@ TokKind HloLexer::LexToken() {
// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
// identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]*
TokKind HloLexer::LexIdentifier() {
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
// 'consumable' will be advanced iff its prefix matches the pattern.
static LazyRE2 shape_pattern = {
R"(^(\w*\d*)\[([\d,\s]*)\](?:(dense|sparse)?{([\d,\s]+)})?)"};
if (RE2::Consume(&consumable, *shape_pattern)) {
auto status_or_shape = ShapeUtil::ParseShapeString(
StringPieceFromPointers(token_start_, consumable.begin()));
if (status_or_shape.ok()) {
// This is a shape string.
shape_val_ = status_or_shape.ValueOrDie();
current_ptr_ = consumable.begin();
return TokKind::kShape;
}
}
}

while (IsIdentifierChar(PeekCurrentChar())) {
current_ptr_++;
}

// If followed by ':', it's a name.
if (PeekCurrentChar() == ':') {
str_val_.assign(token_start_, current_ptr_);
token_state_.str_val.assign(token_state_.token_start, current_ptr_);
current_ptr_++;  // skip ':'
return TokKind::kName;
}

// If followed by '=', it's a attribute name.
if (PeekCurrentChar() == '=') {
str_val_.assign(token_start_, current_ptr_);
token_state_.str_val.assign(token_state_.token_start, current_ptr_);
current_ptr_++;  // skip '='
return TokKind::kAttributeName;
}

absl::string_view identifier =
StringPieceFromPointers(token_start_, current_ptr_);
StringPieceFromPointers(token_state_.token_start, current_ptr_);

// Primitive type strings are reserved words. The exception is 'tuple' whose
// type is represented using nested parentheses without the string 'tuple'.
if (primitive_util::IsPrimitiveTypeName(identifier)) {
PrimitiveType primitive_type =
primitive_util::StringToPrimitiveType(identifier).ValueOrDie();
if (primitive_type != TUPLE) {
token_state_.primitive_type_val = primitive_type;
return TokKind::kPrimitiveType;
}
}

// See if this is a keyword.
#define KEYWORD(STR) \

@@ -261,21 +270,23 @@ TokKind HloLexer::LexIdentifier() {
KEYWORD(ROOT);
KEYWORD(maximal);
KEYWORD(replicated);
KEYWORD(sparse);

#undef KEYWORD

{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
auto consumable =
RegexpStringPieceFromPointers(token_state_.token_start, buf_.end());
static LazyRE2 dim_labels_pattern = {
R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
if (RE2::Consume(&consumable, *dim_labels_pattern)) {
current_ptr_ = consumable.begin();
str_val_.assign(token_start_, current_ptr_);
token_state_.str_val.assign(token_state_.token_start, current_ptr_);
return TokKind::kDimLabels;
}
}

str_val_ = string(identifier);
token_state_.str_val = string(identifier);
return TokKind::kIdent;
}

@@ -289,7 +300,7 @@ TokKind HloLexer::LexPercent() {
while (IsIdentifierChar(PeekCurrentChar())) {
current_ptr_++;
}
str_val_.assign(name_start, current_ptr_);
token_state_.str_val.assign(name_start, current_ptr_);
return TokKind::kName;
}
return TokKind::kError;

@@ -307,12 +318,14 @@ TokKind HloLexer::LexPercent() {
// int ::= [-]?[0-9]+
// negative inf ::= '-inf'
TokKind HloLexer::LexNumberOrPattern() {
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
auto consumable =
RegexpStringPieceFromPointers(token_state_.token_start, buf_.end());
static LazyRE2 float_pattern = {
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
if (RE2::Consume(&consumable, *float_pattern)) {
current_ptr_ = consumable.begin();
CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_));
CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_),
&token_state_.decimal_val));
return TokKind::kDecimal;
}

@@ -324,27 +337,28 @@ TokKind HloLexer::LexNumberOrPattern() {

if (RE2::Consume(&consumable, *dim_labels_pattern)) {
current_ptr_ = consumable.begin();
str_val_.assign(token_start_, current_ptr_);
token_state_.str_val.assign(token_state_.token_start, current_ptr_);
return TokKind::kDimLabels;
}

if (RE2::Consume(&consumable, *dxd_pattern)) {
current_ptr_ = consumable.begin();
str_val_.assign(token_start_, current_ptr_);
token_state_.str_val.assign(token_state_.token_start, current_ptr_);
return TokKind::kDxD;
}

if (RE2::Consume(&consumable, *pad_pattern)) {
current_ptr_ = consumable.begin();
str_val_.assign(token_start_, current_ptr_);
token_state_.str_val.assign(token_state_.token_start, current_ptr_);
return TokKind::kPad;
}

static LazyRE2 int_pattern = {R"([-]?\d+)"};
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
auto slice = StringPieceFromPointers(token_start_, current_ptr_);
if (absl::SimpleAtoi(slice, &int64_val_)) {
auto slice =
StringPieceFromPointers(token_state_.token_start, current_ptr_);
if (absl::SimpleAtoi(slice, &token_state_.int64_val)) {
return TokKind::kInt;
}
LOG(ERROR) << "Failed to parse int literal: " << slice;

@@ -403,16 +417,17 @@ absl::string_view HloLexer::GetLine(LocTy loc) const {
}

// Lexes quoted string with escaping characters. If matched, the quoted string
// will be unescaped and stored to str_val_.
// will be unescaped and stored to token_state_.str_val.
TokKind HloLexer::LexString() {
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
auto consumable =
RegexpStringPieceFromPointers(token_state_.token_start, buf_.end());
static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
if (RE2::Consume(&consumable, *escaping_pattern)) {
current_ptr_ = consumable.begin();
absl::string_view raw =
StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1);
StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1);
string error;
if (!absl::CUnescape(raw, &str_val_, &error)) {
if (!absl::CUnescape(raw, &token_state_.str_val, &error)) {
LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
return TokKind::kError;
}

@@ -467,6 +482,10 @@ string TokKindToString(TokKind kind) {
return "kw_inf";
case TokKind::kNegInf:
return "kNegInf";
case TokKind::kw_sparse:
return "kw_sparse";
case TokKind::kPrimitiveType:
return "kPrimitiveType";
case TokKind::kName:
return "kName";
case TokKind::kAttributeName:

@@ -481,8 +500,6 @@ string TokKindToString(TokKind kind) {
return "kIdent";
case TokKind::kString:
return "kString";
case TokKind::kShape:
return "kShape";
case TokKind::kInt:
return "kInt";
case TokKind::kDecimal:
@@ -19,7 +19,6 @@ limitations under the License.
#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_token.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

@@ -29,6 +28,57 @@ limitations under the License.

namespace xla {

// Defines different kinds of tokens used by the HLO lexer.
//
// You shouldn't need to use this directly unless you're using HloLexer
// directly, and you probably don't need to do that. Use hlo_parser instead.
enum class TokKind {
// Markers
kEof,
kError,

// Tokens with no info.
kEqual,  // =
kComma,  // ,
kColon,  // :
kLsquare,
kRsquare,  // [ ]
kLbrace,
kRbrace,  // { }
kLparen,
kRparen,  // ( )

kArrow,  // ->

// Keywords
kw_HloModule,
kw_ENTRY,
kw_ROOT,
kw_true,
kw_false,
kw_maximal,
kw_replicated,
kw_nan,
kw_inf,
kw_sparse,

kNegInf,  // -inf

// Typed tokens.
kPrimitiveType,  // F32, PRED, etc.
kName,           // %foo
kAttributeName,  // dimensions=
kDimLabels,      // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
kDxD,            // [0-9]+(x[0-9]+)+
kPad,            // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*
kIdent,          // other identifiers
kString,         // "abcd\"\n"
kInt,            // 42
kDecimal,        // 4.2
};

string TokKindToString(TokKind kind);

// Lexer for the HloModule::ToString() format text.
//
// This class is meant to be used by hlo_parser.cc. You shouldn't need to use

@@ -39,9 +89,9 @@ class HloLexer {
current_ptr_ = buf_.begin();
}

TokKind Lex() { return current_kind_ = LexToken(); }
TokKind Lex() { return token_state_.current_kind = LexToken(); }

TokKind GetKind() const { return current_kind_; }
TokKind GetKind() const { return token_state_.current_kind; }
string GetStrVal() const {
switch (GetKind()) {
case TokKind::kName:

@@ -51,28 +101,28 @@ class HloLexer {
case TokKind::kPad:
case TokKind::kString:
case TokKind::kIdent:
return str_val_;
return token_state_.str_val;
default:
LOG(FATAL) << "This token does not have string value";
}
}
Shape GetShapeVal() const {
CHECK(GetKind() == TokKind::kShape);
return shape_val_;
}
tensorflow::int64 GetInt64Val() const {
CHECK(GetKind() == TokKind::kInt);
return int64_val_;
return token_state_.int64_val;
}
double GetDecimalVal() const {
CHECK(GetKind() == TokKind::kDecimal);
return decimal_val_;
return token_state_.decimal_val;
}
PrimitiveType GetPrimitiveTypeVal() const {
CHECK(GetKind() == TokKind::kPrimitiveType);
return token_state_.primitive_type_val;
}

typedef const char* LocTy;

// Returns the location of the current token.
LocTy GetLoc() const { return token_start_; }
LocTy GetLoc() const { return token_state_.token_start; }

// Returns the line and column of a location in the buffer.
std::pair<unsigned, unsigned> GetLineAndColumn(LocTy location) const;

@@ -80,6 +130,9 @@ class HloLexer {
// Returns the whole line given the location.
absl::string_view GetLine(LocTy loc) const;

// Looks ahead one token and returns it. Lexer state is unchanged.
TokKind LookAhead();

private:
// Returns the current character. If it's neither the end of input buffer nor
// an invalid character, moves the pointer forward.

@@ -112,12 +165,15 @@ class HloLexer {
const char* current_ptr_;

// Information about the current token.
const char* token_start_ = nullptr;
TokKind current_kind_;
string str_val_;
Shape shape_val_;
tensorflow::int64 int64_val_;
double decimal_val_;
struct TokenState {
const char* token_start = nullptr;
TokKind current_kind;
string str_val;
tensorflow::int64 int64_val;
double decimal_val;
PrimitiveType primitive_type_val;
};
TokenState token_state_;

struct LineNoCacheTy {
const char* last_query;
@@ -403,9 +403,9 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
HloModule OutfeedLoop
WhileBody {
body_param = (s32[]) parameter(0)
token = token[] after-all()
token0 = token[] after-all()
constant.2 = s32[] constant(2)
outfeed_tuple = (s32[]) outfeed(constant.2, token)
outfeed_tuple = (s32[]) outfeed(constant.2, token0)
get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)

@@ -436,9 +436,9 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
HloModule OutfeedLoop
InnerWhileBody {
body_param = (s32[]) parameter(0)
token = token[] after-all()
token0 = token[] after-all()
constant.2 = s32[] constant(2)
outfeed_tuple = (s32[]) outfeed(constant.2, token)
outfeed_tuple = (s32[]) outfeed(constant.2, token0)
get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
@@ -312,8 +312,8 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
}
inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
absl::string_view shape) {
return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(
ShapeUtil::ParseShapeString(shape).ValueOrDie()));
return ::testing::MakeMatcher(
new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie()));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
const class Shape& shape) {

@@ -323,7 +323,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
absl::string_view shape) {
return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
ShapeUtil::ParseShapeString(shape).ValueOrDie()));
ParseShape(shape).ValueOrDie()));
}

// Verifies the value of the HloSharing against the provided sharding object.
@@ -373,9 +373,9 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) {
HloModule OutfeedLoop
WhileBody {
body_param = (s32[]) parameter(0)
token = token[] after-all()
token0 = token[] after-all()
constant.2 = s32[] constant(2)
outfeed_tuple = (s32[]) outfeed(constant.2, token)
outfeed_tuple = (s32[]) outfeed(constant.2, token0)
get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
@ -74,6 +75,7 @@ class HloParser {
|
||||
string GetError() const { return StrJoin(error_, "\n"); }
|
||||
|
||||
// Stand alone parsing utils for various aggregate data types.
|
||||
StatusOr<Shape> ParseShapeOnly();
|
||||
StatusOr<HloSharding> ParseShardingOnly();
|
||||
StatusOr<Window> ParseWindowOnly();
|
||||
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
|
||||
@ -255,7 +257,9 @@ class HloParser {
|
||||
bool ParseName(string* result);
|
||||
bool ParseAttributeName(string* result);
|
||||
bool ParseString(string* result);
|
||||
bool ParseDimensionSizes(std::vector<int64>* dimension_sizes);
|
||||
bool ParseShape(Shape* result);
|
||||
bool ParseLayout(Layout* layout);
|
||||
bool ParseOpcode(HloOpcode* result);
|
||||
bool ParseFftType(FftType* result);
|
||||
bool ParseFusionKind(HloInstruction::FusionKind* result);
|
||||
@ -279,9 +283,6 @@ class HloParser {
|
||||
// If the current token is 'kind', eats it (i.e. lexes the next token) and
|
||||
// returns true.
|
||||
bool EatIfPresent(TokKind kind);
|
||||
// Parses a shape, and returns true if the result is compatible with the given
// shape.
bool EatShapeAndCheckCompatible(const Shape& shape);

// Adds the instruction to the pool. Returns false and emits an error if the
// instruction already exists.
@ -1697,11 +1698,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
}
break;
}
case TokKind::kShape:
// TODO(b/112302613): Left here for backward compatibility to ignore the
// removed tile shape data.
lexer_.Lex();
break;
case TokKind::kRbrace:
break;
default:
@ -1925,19 +1921,6 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value,
return true;
}

bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
Shape new_shape;
if (!ParseShape(&new_shape)) {
return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape)));
}
if (!ShapeUtil::Compatible(shape, new_shape)) {
return TokenError(StrCat(
"expects shape ", ShapeUtil::HumanString(shape),
", but sees a different shape: ", ShapeUtil::HumanString(new_shape)));
}
return true;
}

// literal
// ::= tuple
// ::= non_tuple
@ -1952,10 +1935,6 @@ bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
// ::= /*empty*/
// ::= literal (',' literal)*
bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return TokenError(StrCat("expects tuple constant in shape ",
ShapeUtil::HumanString(shape)));
}
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
@ -1990,16 +1969,12 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
return ParseSparseLiteral(literal, shape);
}

CHECK(LayoutUtil::IsDenseArray(shape));
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
return ParseDenseLiteral(literal, shape);
}

bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
}

// Create a literal with the given shape in default layout.
*literal = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
@ -2126,10 +2101,6 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
}

bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return false;
}

switch (shape.element_type()) {
case PRED:
return ParseSparseLiteralHelper<tensorflow::uint8>(literal, shape);
@ -2994,6 +2965,39 @@ bool HloParser::ParseParamList() {
return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
}

// dimension_sizes ::= '[' int64_list ']'
bool HloParser::ParseDimensionSizes(std::vector<int64>* dimension_sizes) {
auto parse_and_add_item = [&]() {
tensorflow::int64 i;
if (!ParseInt64(&i)) {
return false;
}
dimension_sizes->push_back(i);
return true;
};
return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
parse_and_add_item);
}

// layout ::= '{' int64_list '}'
bool HloParser::ParseLayout(Layout* layout) {
std::vector<int64> minor_to_major;
auto parse_and_add_item = [&]() {
tensorflow::int64 i;
if (!ParseInt64(&i)) {
return false;
}
minor_to_major.push_back(i);
return true;
};
if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
parse_and_add_item)) {
return false;
}
*layout = LayoutUtil::MakeLayout(minor_to_major);
return true;
}

// shape ::= shape_val_
// shape ::= '(' tuple_elements ')'
// tuple_elements
@ -3017,19 +3021,61 @@ bool HloParser::ParseShape(Shape* result) {
return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
}

if (lexer_.GetKind() != TokKind::kShape) {
return TokenError(absl::StrCat("expected shape, saw ",
if (lexer_.GetKind() != TokKind::kPrimitiveType) {
return TokenError(absl::StrCat("expected primitive type, saw ",
TokKindToString(lexer_.GetKind())));
}
*result = lexer_.GetShapeVal();
PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
lexer_.Lex();

std::vector<int64> dimension_sizes;
if (!ParseDimensionSizes(&dimension_sizes)) {
return false;
}
result->set_element_type(primitive_type);
*result->mutable_dimensions() = dimension_sizes;
LayoutUtil::SetToDefaultLayout(result);

if (lexer_.GetKind() == TokKind::kw_sparse) {
lexer_.Lex();
const string message =
"expects a brace-bracketed integer for sparse layout";
tensorflow::int64 max_sparse_elements;
if (!ParseToken(TokKind::kLbrace, message) ||
!ParseInt64(&max_sparse_elements) ||
!ParseToken(TokKind::kRbrace, message)) {
return false;
}
*result->mutable_layout() =
LayoutUtil::MakeSparseLayout(max_sparse_elements);
return true;
}

// We need to lookahead to see if a following open brace is the start of a
// layout. The specific problematic case is:
//
// ENTRY %foo (x: f32[42]) -> f32[123] {
// ...
// }
//
// The open brace could either be the start of a computation or the start of a
// layout for the f32[123] shape. We consider it the start of a layout if the
// next token after the open brace is an integer.
if (lexer_.GetKind() == TokKind::kLbrace &&
lexer_.LookAhead() == TokKind::kInt) {
Layout layout;
if (!ParseLayout(&layout)) {
return false;
}
*result->mutable_layout() = layout;
}
return true;
}
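// To make the grammar above concrete, a rough sketch of how two
// representative shape strings are consumed (illustrative only; the exact
// token spellings come from the lexer):
//
//   "f32[123,456]{0,1}"
//     kPrimitiveType 'f32'   -> element type
//     '[' 123 ',' 456 ']'    -> ParseDimensionSizes
//     '{' 0 ',' 1 '}'        -> ParseLayout, minor_to_major = {0,1}
//
//   "f32[123,456]sparse{10}"
//     same element type and dimensions, then kw_sparse and a braced count
//     -> LayoutUtil::MakeSparseLayout(10)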

bool HloParser::CanBeShape() {
// A non-tuple shape starts with a kShape token; a tuple shape starts with
// '('.
return lexer_.GetKind() == TokKind::kShape ||
// A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
// with '('.
return lexer_.GetKind() == TokKind::kPrimitiveType ||
lexer_.GetKind() == TokKind::kLparen;
}

@ -3332,6 +3378,18 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation,
return true;
}

StatusOr<Shape> HloParser::ParseShapeOnly() {
lexer_.Lex();
Shape shape;
if (!ParseShape(&shape)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after shape");
}
return shape;
}

StatusOr<HloSharding> HloParser::ParseShardingOnly() {
lexer_.Lex();
OpSharding op_sharding;
@ -3475,4 +3533,9 @@ StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
return parser.ParsePaddingConfigOnly();
}

StatusOr<Shape> ParseShape(absl::string_view str) {
HloParser parser(str);
return parser.ParseShapeOnly();
}

} // namespace xla

@ -60,6 +60,9 @@ StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1".
StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str);

// Parses and returns a Shape::ToString-format string.
StatusOr<Shape> ParseShape(absl::string_view str);

} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
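For orientation, a minimal usage sketch of the ParseShape entry point declared above (hypothetical caller code, not part of this change; it assumes only this declaration plus the existing ShapeUtil helpers used elsewhere in the tests):

  #include "tensorflow/compiler/xla/service/hlo_parser.h"
  #include "tensorflow/compiler/xla/shape_util.h"
  #include "tensorflow/core/platform/logging.h"

  void ExampleParseShapeUse() {
    xla::StatusOr<xla::Shape> shape_or = xla::ParseShape("f32[123,456]{0,1}");
    CHECK(shape_or.ok()) << shape_or.status();
    const xla::Shape& shape = shape_or.ValueOrDie();
    // Element type, dimensions, and minor-to-major layout all round-trip.
    CHECK(xla::ShapeUtil::Equal(
        shape,
        xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {123, 456}, {0, 1})));
  }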
@ -82,7 +82,7 @@ ENTRY %constant_pred () -> pred[] {
|
||||
R"(HloModule module
|
||||
|
||||
ENTRY %constant_pred_array () -> pred[2,3] {
|
||||
ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } })
|
||||
ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } })
|
||||
}
|
||||
|
||||
)"
|
||||
@ -128,7 +128,7 @@ ENTRY %ConstantF32Empty.v4 () -> f32[0] {
|
||||
R"(HloModule ConstantF32R4Empty_module
|
||||
|
||||
ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
|
||||
ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } })
|
||||
ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } })
|
||||
}
|
||||
|
||||
)"
|
||||
@ -139,7 +139,7 @@ ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
|
||||
R"(HloModule Small_3x2x1x1_module
|
||||
|
||||
ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] {
|
||||
ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
|
||||
ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
|
||||
}
|
||||
|
||||
)"
|
||||
@ -196,7 +196,7 @@ ENTRY %add_constants () -> f32[] {
|
||||
R"(HloModule TupleConstant_module
|
||||
|
||||
ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) {
|
||||
ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} ))
|
||||
ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} ))
|
||||
}
|
||||
|
||||
)"
|
||||
@ -295,11 +295,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
|
||||
R"(HloModule TwoSendRecvBothWayRecvFist_module
|
||||
|
||||
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
|
||||
%token = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1}
|
||||
%token0 = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
|
||||
ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
|
||||
%constant = f32[] constant(2.1), sharding={maximal device=0}
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
|
||||
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
|
||||
}
|
||||
|
||||
@ -310,11 +310,11 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
|
||||
R"(HloModule HostTransferSendRecv_module
|
||||
|
||||
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
|
||||
%token = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, is_host_transfer=true
|
||||
%token0 = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true
|
||||
ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true
|
||||
%constant = f32[] constant(2.1), sharding={maximal device=0}
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, is_host_transfer=true
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true
|
||||
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true
|
||||
}
|
||||
|
||||
@ -327,7 +327,7 @@ R"(HloModule GetTupleElement_module
|
||||
|
||||
ENTRY %GetTupleElement.v4 () -> s32[2,3] {
|
||||
%constant = f32[3]{0} constant({1, 2, 3})
|
||||
%constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } })
|
||||
%constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
|
||||
%tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1)
|
||||
ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0}
|
||||
}
|
||||
@ -434,7 +434,7 @@ ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f
|
||||
R"(HloModule Reverse4DFloatArrayOnDim01_module
|
||||
|
||||
ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
|
||||
%constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
|
||||
%constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
|
||||
ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
|
||||
}
|
||||
|
||||
@ -446,8 +446,8 @@ ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
|
||||
R"(HloModule Concat2x3With2x5_module
|
||||
|
||||
ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
|
||||
%constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } })
|
||||
%constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
|
||||
%constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } })
|
||||
%constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
|
||||
ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
|
||||
}
|
||||
|
||||
@ -471,8 +471,8 @@ R"(HloModule R4F32OverlapSmall_module
|
||||
}
|
||||
|
||||
ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
|
||||
%constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
|
||||
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
|
||||
%constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
|
||||
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
|
||||
%constant.2 = f32[] constant(0)
|
||||
ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
|
||||
}
|
||||
@ -523,7 +523,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
|
||||
R"(HloModule Slice3x3x3_To_1x3x3_F32_module
|
||||
|
||||
ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
|
||||
%constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
|
||||
%constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
|
||||
ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
|
||||
}
|
||||
|
||||
@ -547,7 +547,7 @@ ENTRY %SliceR0.v2 () -> s32[] {
|
||||
R"(HloModule Transpose_module
|
||||
|
||||
ENTRY %Transpose.v2 () -> s32[1,2,3] {
|
||||
%constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } })
|
||||
%constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } })
|
||||
ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
|
||||
}
|
||||
|
||||
@ -588,7 +588,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_
|
||||
R"(HloModule BasicTraining_module
|
||||
|
||||
ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
|
||||
%constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } })
|
||||
%constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } })
|
||||
%constant.1 = f32[2]{0} constant({2, 3})
|
||||
%constant.2 = f32[2]{0} constant({1, 2})
|
||||
ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
|
||||
@ -728,7 +728,7 @@ R"(HloModule fusion_module
|
||||
}
|
||||
|
||||
ENTRY %fusion.v3 () -> f32[3,2,1,1] {
|
||||
%constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
|
||||
%constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
|
||||
%constant.1 = f32[2]{0} constant({3.14, 4.25})
|
||||
ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
|
||||
}
|
||||
@ -740,7 +740,7 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] {
|
||||
R"(HloModule sparse_f32
|
||||
|
||||
ENTRY %sparse () -> f32[2,3,4] {
|
||||
ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3})
|
||||
ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3})
|
||||
}
|
||||
|
||||
)"
|
||||
@ -750,7 +750,7 @@ ENTRY %sparse () -> f32[2,3,4] {
|
||||
R"(HloModule sparse_f32_empty
|
||||
|
||||
ENTRY %sparse_f32_empty () -> f32[2,3,4] {
|
||||
ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{})
|
||||
ROOT %foo = f32[2,3,4]sparse{10} constant({})
|
||||
}
|
||||
|
||||
)"
|
||||
@ -760,7 +760,7 @@ ENTRY %sparse_f32_empty () -> f32[2,3,4] {
|
||||
R"(HloModule sparse_f32_r1
|
||||
|
||||
ENTRY %sparse_f32_r1 () -> f32[9] {
|
||||
ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6})
|
||||
ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6})
|
||||
}
|
||||
|
||||
)"
|
||||
@ -931,11 +931,11 @@ ENTRY reduce_entry {
|
||||
R"(HloModule outfeed_module
|
||||
|
||||
ENTRY InfeedToOutfeed {
|
||||
token = token[] after-all()
|
||||
infeed = ((u32[3]{0}, pred[]), token[]) infeed(token)
|
||||
token0 = token[] after-all()
|
||||
infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
|
||||
infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
|
||||
outfeed = token[] outfeed(infeed.data, token)
|
||||
ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token)
|
||||
outfeed = token[] outfeed(infeed.data, token0)
|
||||
ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
|
||||
infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
|
||||
infeed.1.token = token[] get-tuple-element(infeed.1), index=1
|
||||
outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
|
||||
@ -1266,8 +1266,8 @@ R"(HloModule AddDependency
|
||||
ENTRY AddDependency {
|
||||
p = f32[] parameter(0)
|
||||
neg = f32[] negate(p)
|
||||
token = token[] after-all(neg)
|
||||
p_after_token = f32[] add-dependency(p, token)
|
||||
token0 = token[] after-all(neg)
|
||||
p_after_token = f32[] add-dependency(p, token0)
|
||||
exp = f32[] exponential(p_after_token)
|
||||
ROOT sum = f32[] add(neg, exp)
|
||||
}
|
||||
@ -1419,7 +1419,7 @@ TEST_F(HloParserTest, MoreConstants) {
|
||||
|
||||
ENTRY %SelectScalarS32True.v4 () -> s32[] {
|
||||
%constant.2 = pred[] constant(true)
|
||||
%constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4}
|
||||
%constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4}
|
||||
%constant = s32[] constant(42)
|
||||
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
|
||||
}
|
||||
@ -1462,7 +1462,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_2) {
|
||||
const string original = R"(HloModule some_2x3_module
|
||||
|
||||
ENTRY %some_2x3 () -> f32[2,3] {
|
||||
ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6})
|
||||
ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6})
|
||||
}
|
||||
|
||||
)";
|
||||
@ -1476,7 +1476,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_3) {
|
||||
const string original = R"(HloModule some_2x3x2_module
|
||||
|
||||
ENTRY %some_2x3x2 () -> f32[2,3,2] {
|
||||
ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
|
||||
ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
|
||||
}
|
||||
|
||||
)";
|
||||
@ -1594,11 +1594,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) {
|
||||
const string original = R"(HloModule unexpected_attr_module
|
||||
|
||||
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
|
||||
%token = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15
|
||||
%token0 = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
|
||||
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
|
||||
ROOT %constant = f32[] constant(2.1)
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv
|
||||
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
|
||||
}
|
||||
|
||||
@ -1611,11 +1611,11 @@ TEST_F(HloParserTest, MissingAttribute) {
|
||||
const string original = R"(HloModule missing_attr_module
|
||||
|
||||
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
|
||||
%token = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15
|
||||
%token0 = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
|
||||
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
|
||||
ROOT %constant = f32[] constant(-2.1)
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token)
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
|
||||
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
|
||||
}
|
||||
|
||||
@ -1628,11 +1628,11 @@ TEST_F(HloParserTest, PredecessorUndefined) {
|
||||
const string original = R"(HloModule pre_not_found_module
|
||||
|
||||
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
|
||||
%token = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15
|
||||
%token0 = token[] after-all()
|
||||
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
|
||||
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
|
||||
ROOT %constant = f32[] constant(2.1)
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done}
|
||||
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done}
|
||||
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
|
||||
}
|
||||
|
||||
@ -1940,8 +1940,8 @@ TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
|
||||
TEST_F(HloParserTest, NontupleInfeed) {
|
||||
const string original = R"(HloModule nontuple_infeed:
|
||||
ENTRY nontuple_infeed {
|
||||
token = token[] after-all()
|
||||
ROOT infeed = pred[] infeed(token)
|
||||
token0 = token[] after-all()
|
||||
ROOT infeed = pred[] infeed(token0)
|
||||
})";
|
||||
ExpectHasSubstr(ParseHloString(original).status().error_message(),
|
||||
"infeed must have a non-empty tuple shape");
|
||||
@ -2239,7 +2239,7 @@ HloModule foobar
|
||||
|
||||
ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
|
||||
%p = f32[2,2] parameter(0)
|
||||
%constant.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
|
||||
%constant.1 = f32[2,2] constant({{1, 2}, {3, 4}})
|
||||
ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1)
|
||||
}
|
||||
)";
|
||||
@ -2249,7 +2249,85 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
|
||||
" with the shape of the operand instruction f32[2,2]{1,0}.");
|
||||
}
|
||||

// custom call incompatible shape.
TEST_F(HloParserTest, ParseShapeStringR2F32) {
string shape_string = "f32[123,456]";
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) {
string shape_string = "(f32[1572864],s8[5120,1024])";
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
Shape expected =
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
ShapeUtil::MakeShape(S8, {5120, 1024})});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST_F(HloParserTest, ParseShapeStringNestedTuple) {
string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
Shape expected = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
ShapeUtil::MakeOpaqueShape(),
ShapeUtil::MakeShape(F32, {3}),
});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST_F(HloParserTest, ParseShapeStringWithLayout) {
string shape_string = "f32[123,456]{0,1}";
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) {
string shape_string = "f32[123,456]sparse{10}";
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10);
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST_F(HloParserTest, ParseOpaqueType) {
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]"));
Shape expected = ShapeUtil::MakeOpaqueShape();
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST_F(HloParserTest, ParseTokenType) {
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]"));
Shape expected = ShapeUtil::MakeTokenShape();
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST_F(HloParserTest, ParseInvalidShapeString) {
string shape_strings[] = {
"f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
"f32[123,456]dense{foo}", "f32[123,456]sparse{foo}",
};
for (const string& shape_string : shape_strings) {
StatusOr<Shape> result = ParseShape(shape_string);
ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
}
}

} // namespace
} // namespace xla

@ -1,78 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_

#include <string>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Defines different kinds of tokens in a hlo module string.
//
// You shouldn't need to use this directly unless you're using HloLexer
// directly, and you probably don't need to do that. Use hlo_parser instead.
enum class TokKind {
// Markers
kEof,
kError,

// Tokens with no info.
kEqual, // =
kComma, // ,
kColon, // :
kLsquare,
kRsquare, // [ ]
kLbrace,
kRbrace, // { }
kLparen,
kRparen, // ( )

kArrow, // ->

// Keywords
kw_HloModule,
kw_ENTRY,
kw_ROOT,
kw_true,
kw_false,
kw_maximal,
kw_replicated,
kw_nan,
kw_inf,

kNegInf, // -inf

// Typed tokens.
kName, // %foo
kAttributeName, // dimensions=
kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
kDxD, // [0-9]+(x[0-9]+)+
kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*
kIdent, // other identifiers
kString, // "abcd\"\n"
kShape, // f32[2,3]{1,0}
kInt, // 42
kDecimal, // 4.2
};

string TokKindToString(TokKind kind);

} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_
@ -99,7 +99,7 @@ TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) {
|
||||
HloModule SimpleGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
|
||||
operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}})
|
||||
indices = s32[5] parameter(0)
|
||||
ROOT gather = s32[5,3] gather(operand, indices),
|
||||
offset_dims={1},
|
||||
@ -119,7 +119,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) {
|
||||
HloModule SimpleGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
|
||||
operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}})
|
||||
indices = s32[5,2] parameter(0)
|
||||
ROOT gather = s32[5] gather(operand, indices),
|
||||
offset_dims={},
|
||||
@ -195,7 +195,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) {
|
||||
HloModule SimpleGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
|
||||
operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}})
|
||||
indices_a = s32[5] parameter(0)
|
||||
indices_b = s32[2] parameter(1)
|
||||
gather_a = s32[5,3] gather(operand, indices_a),
|
||||
@ -309,7 +309,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
|
||||
operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}})
|
||||
indices = s32[5] parameter(0)
|
||||
gather = s32[5,4] gather(operand, indices),
|
||||
offset_dims={1},
|
||||
@ -330,7 +330,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
|
||||
operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}})
|
||||
indices = s32[5,7] parameter(0)
|
||||
gather = s32[5,4,7] gather(operand, indices),
|
||||
offset_dims={1},
|
||||
@ -352,7 +352,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,2,6] constant(s32[3,2,6]{
|
||||
operand = s32[3,2,6] constant({
|
||||
{{1,2,3,4,5,6},{1,2,3,4,5,6}},
|
||||
{{1,2,3,4,5,6},{1,2,3,4,5,6}},
|
||||
{{1,2,3,4,5,6},{1,2,3,4,5,6}}})
|
||||
@ -377,7 +377,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[2,6] constant(s32[2,6]{
|
||||
operand = s32[2,6] constant({
|
||||
{1,2,3,4,5,6},{1,2,3,4,5,6}})
|
||||
indices = s32[1] parameter(0)
|
||||
gather = s32[1,6] gather(operand, indices),
|
||||
@ -405,7 +405,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } })
|
||||
operand = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 1, 2, 3 } })
|
||||
|
||||
i.0 = s64[1,3]{1,0} parameter(0)
|
||||
g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2},
|
||||
@ -438,7 +438,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}})
|
||||
operand = s32[1,6] constant({{1,2,3,4,5,6}})
|
||||
indices = s32[1] parameter(0)
|
||||
gather = s32[1,6] gather(operand, indices),
|
||||
offset_dims={1},
|
||||
@ -465,7 +465,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[1,2,6] constant(s32[1,2,6]{{
|
||||
operand = s32[1,2,6] constant({{
|
||||
{1,2,3,4,5,6},{1,2,3,4,5,6}}})
|
||||
indices = s32[1] parameter(0)
|
||||
gather = s32[1,1,6] gather(operand, indices),
|
||||
@ -496,7 +496,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[2,6] constant(s32[2,6]{
|
||||
operand = s32[2,6] constant({
|
||||
{1,2,3,4,5,6},{1,2,3,4,5,6}})
|
||||
indices = s32[1,5] parameter(0)
|
||||
gather = s32[1,5,6] gather(operand, indices),
|
||||
@ -527,7 +527,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
|
||||
operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}})
|
||||
indices = s32[5,6] parameter(0)
|
||||
gather = s32[5,4,6] gather(operand, indices),
|
||||
offset_dims={1},
|
||||
@ -556,7 +556,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,5,2] constant(s32[3,5,2]{
|
||||
operand = s32[3,5,2] constant({
|
||||
{{1,2},{3,4},{5,6},{7,8},{9,10}},
|
||||
{{1,2},{3,4},{5,6},{7,8},{9,10}},
|
||||
{{1,2},{3,4},{5,6},{7,8},{9,10}}})
|
||||
@ -588,7 +588,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) {
|
||||
HloModule ReshapeOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,4,1] constant(s32[3,4,1]{
|
||||
operand = s32[3,4,1] constant({
|
||||
{{1},{2},{3},{4}},
|
||||
{{1},{2},{3},{4}},
|
||||
{{1},{2},{3},{4}}})
|
||||
@ -620,7 +620,7 @@ TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) {
|
||||
HloModule UnaryOpOfGather
|
||||
|
||||
ENTRY main {
|
||||
operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
operand = f32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
indices = s32[5] parameter(0)
|
||||
gather = f32[5,4] gather(operand, indices),
|
||||
offset_dims={1},
|
||||
@ -645,7 +645,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) {
|
||||
HloModule AddBroadcastedScalarWithGather
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
constant = s32[] constant(5)
|
||||
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
|
||||
indices = s32[5] parameter(0)
|
||||
@ -673,7 +673,7 @@ TEST_F(IndexedArrayAnalysisTest,
|
||||
HloModule SubtractBroadcastedScalarWithGather
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
constant = s32[] constant(5)
|
||||
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
|
||||
indices = s32[5] parameter(0)
|
||||
@ -701,7 +701,7 @@ TEST_F(IndexedArrayAnalysisTest,
|
||||
HloModule SubtractBroadcastedScalarWithGather
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
constant = s32[] constant(5)
|
||||
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
|
||||
indices = s32[5] parameter(0)
|
||||
@ -728,7 +728,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) {
|
||||
HloModule AddBroadcastedVectorWithGather
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
constant_vect = s32[4] constant({10,11,12,13})
|
||||
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
|
||||
indices = s32[5] parameter(0)
|
||||
@ -755,7 +755,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) {
|
||||
HloModule AddBroadcastedVectorWithGather
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}})
|
||||
constant_vect = s32[5] constant({10,11,12,13,14})
|
||||
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
|
||||
indices = s32[5] parameter(0)
|
||||
@ -804,8 +804,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) {
|
||||
HloModule DotOp
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
|
||||
dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
|
||||
gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
|
||||
dot_rhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
|
||||
indices = s32[5] parameter(0)
|
||||
dot_lhs = s32[5,4] gather(gather_operand, indices),
|
||||
offset_dims={1},
|
||||
@ -831,8 +831,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) {
|
||||
HloModule DotOp
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
|
||||
dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}})
|
||||
gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
|
||||
dot_rhs_constant = s32[3,3] constant({{1,2,3},{4,5,6},{7,8,9}})
|
||||
indices = s32[5] parameter(0)
|
||||
dot_lhs = s32[3,5] gather(gather_operand, indices),
|
||||
offset_dims={0},
|
||||
@ -859,8 +859,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) {
|
||||
HloModule DotOp
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
|
||||
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
|
||||
gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
|
||||
dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
|
||||
indices = s32[5] parameter(0)
|
||||
dot_rhs = s32[3,5] gather(gather_operand, indices),
|
||||
offset_dims={0},
|
||||
@ -888,8 +888,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) {
|
||||
HloModule DotOp
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
|
||||
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
|
||||
gather_operand = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
|
||||
dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
|
||||
indices = s32[5] parameter(0)
|
||||
dot_rhs = s32[5,3] gather(gather_operand, indices),
|
||||
offset_dims={1},
|
||||
@ -917,8 +917,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) {
|
||||
HloModule DotOp
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}})
|
||||
dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
|
||||
gather_operand = s32[2,3,2] constant({{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}})
|
||||
dot_lhs_constant = s32[2,2,3] constant({{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
|
||||
indices = s32[4] parameter(0)
|
||||
dot_rhs = s32[2,3,4] gather(gather_operand, indices),
|
||||
offset_dims={0,1},
|
||||
@ -948,8 +948,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpNegative) {
|
||||
HloModule DotOp
|
||||
|
||||
ENTRY main {
|
||||
gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
|
||||
dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}})
|
||||
gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}})
|
||||
dot_rhs_constant = s32[2,3] constant({{1,2,3},{4,5,6}})
|
||||
indices = s32[2] parameter(0)
|
||||
dot_lhs = s32[3,2] gather(gather_operand, indices),
|
||||
offset_dims={0},
|
||||
|
@ -259,8 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
|
||||
add = f32[4,3]{1,0} add(p0, p0)
|
||||
abs1 = f32[4,3]{1,0} abs(add)
|
||||
log = f32[4,3]{1,0} log(abs1)
|
||||
token = token[] after-all()
|
||||
send = f32[4,3]{1,0} send(log, token), channel_id=0
|
||||
token0 = token[] after-all()
|
||||
send = f32[4,3]{1,0} send(log, token0), channel_id=0
|
||||
abs2 = f32[4,3]{1,0} abs(log)
|
||||
ROOT root = f32[4,3]{1,0} subtract(abs2, add)
|
||||
})")
|
||||
@ -290,8 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
|
||||
p0 = f32[4,3]{1,0} parameter(0)
|
||||
add1 = f32[4,3]{1,0} add(p0, p0)
|
||||
log = f32[4,3]{1,0} log(p0)
|
||||
token = token[] after-all()
|
||||
send = f32[4,3]{1,0} send(log, token), channel_id=0
|
||||
token0 = token[] after-all()
|
||||
send = f32[4,3]{1,0} send(log, token0), channel_id=0
|
||||
add2 = f32[4,3]{1,0} add(log, add1)
|
||||
ROOT root = f32[4,3]{1,0} subtract(add1, add2)
|
||||
})")
|
||||
@ -324,8 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
|
||||
add1 = f32[4,3]{1,0} add(p0, p0)
|
||||
add2 = f32[4,3]{1,0} add(add1, add1)
|
||||
log = f32[4,3]{1,0} log(add2)
|
||||
token = token[] after-all()
|
||||
send = f32[4,3]{1,0} send(log, token), channel_id=0
|
||||
token0 = token[] after-all()
|
||||
send = f32[4,3]{1,0} send(log, token0), channel_id=0
|
||||
sub1 = f32[4,3]{1,0} subtract(log, add2)
|
||||
sub2 = f32[4,3]{1,0} subtract(add2, add1)
|
||||
ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)
|
||||
|
@ -847,12 +847,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
|
||||
ENTRY entry_computation {
|
||||
param = (f32[2,2]) parameter(0)
|
||||
gte = f32[2,2] get-tuple-element(param), index=0
|
||||
token = token[] after-all()
|
||||
recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1}
|
||||
token0 = token[] after-all()
|
||||
recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
|
||||
recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
|
||||
sharding={maximal device=1}
|
||||
ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
|
||||
send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1,
|
||||
send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
|
||||
sharding={maximal device=0}
|
||||
send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
|
||||
}
|
||||
@ -897,7 +897,7 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
|
||||
ar.0 = f32[2,2] cross-replica-sum(gte),
|
||||
all_reduce_id=1, replica_groups={{0}}, to_apply=add,
|
||||
sharding={maximal device=0}
|
||||
const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}})
|
||||
const = f32[2,2] constant({{0,1},{2,3}})
|
||||
ROOT ar.1 = f32[2,2] cross-replica-sum(const),
|
||||
all_reduce_id=1, replica_groups={{0}}, to_apply=add,
|
||||
sharding={maximal device=1}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.

#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@ -42,6 +43,7 @@ NameUniquer::NameUniquer(const string& separator) {
if (name.empty()) {
return "";
}

string result = name;
char c = static_cast<unsigned char>(result[0]);
if (!isalpha(c) && c != '_') {
@ -52,6 +54,13 @@ NameUniquer::NameUniquer(const string& separator) {
result[i] = '_';
}
}

// HLO primitive type names (with the exception of 'tuple') are keywords in
// the HLO text representation and cannot be names, so append an underscore if
// the name is a primitive type.
if (primitive_util::IsPrimitiveTypeName(result) && result != "tuple") {
result += "_";
}
return result;
}
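// Illustrative examples of the sanitization above (see the
// NameUniquerTest.AvoidKeywords test below; primitive type keywords are
// lowercase only):
//   GetUniqueName("f32")   returns "f32_"
//   GetUniqueName("tuple") returns "tuple"
//   GetUniqueName("F32")   returns "F32"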

@ -104,5 +104,21 @@ TEST_F(NameUniquerTest, KeepNamesInRandomOrder) {
EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3"));
}

TEST_F(NameUniquerTest, AvoidKeywords) {
NameUniquer uniquer(".");

EXPECT_EQ("f32_", uniquer.GetUniqueName("f32"));
EXPECT_EQ("s64_", uniquer.GetUniqueName("s64"));
EXPECT_EQ("pred_", uniquer.GetUniqueName("pred"));

// Though a primitive type, "tuple" is not a keyword.
EXPECT_EQ("tuple", uniquer.GetUniqueName("tuple"));

// Keywords are not capitalized.
EXPECT_EQ("F32", uniquer.GetUniqueName("F32"));
EXPECT_EQ("S32", uniquer.GetUniqueName("S32"));
EXPECT_EQ("Pred", uniquer.GetUniqueName("Pred"));
}

} // namespace
} // namespace xla
|
@ -1737,7 +1737,8 @@ class HloConstantScalarImpl {
|
||||
literal_r0_as_val_ty_or.ValueOrDie() == val_literal &&
|
||||
literal_r0 == val_as_literal_ty;
|
||||
if (!rv) {
|
||||
EXPLAIN << "HloInstruction's constant value " << literal_r0.ToString()
|
||||
EXPLAIN << "HloInstruction's constant value "
|
||||
<< literal_r0.ToStringWithoutShape()
|
||||
<< " did not match expected value " << *val_;
|
||||
}
|
||||
return rv;
|
||||
|
@ -242,8 +242,8 @@ TEST(PatternMatcherTest, ConstantScalar) {
|
||||
HloModule test_module
|
||||
ENTRY test {
|
||||
a = s32[] constant(1)
|
||||
b = s32[1,1] constant(s32[1,1]{{2}})
|
||||
c = s32[1,2] constant(s32[1,2]{{2,2}})
|
||||
b = s32[1,1] constant({{2}})
|
||||
c = s32[1,2] constant({{2,2}})
|
||||
d = f32[] constant(1)
|
||||
e = f32[] constant(1.25)
|
||||
ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
|
||||
|
@ -139,9 +139,9 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
|
||||
HloModule FoldDotTransposeConstant
|
||||
|
||||
ENTRY entry_computation {
|
||||
constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } })
|
||||
constant = f32[2,1]{1,0} constant({ { 1 }, { 2 } })
|
||||
transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0}
|
||||
constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } })
|
||||
constant.1 = f32[3,2]{1,0} constant({ { 1, 2 }, { 3, 4 }, { 5, 6 } })
|
||||
transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0}
|
||||
ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ condition {
|
||||
|
||||
ENTRY entry {
|
||||
const_0 = f32[2] constant({1, 2})
|
||||
const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1}))
|
||||
const_1 = (f32[2], f32[2]) constant(({2, 1},{3,1}))
|
||||
while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1)
|
||||
ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body
|
||||
}
|
||||
@ -206,8 +206,8 @@ body {
|
||||
p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0
|
||||
p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1
|
||||
|
||||
token = token[] after-all()
|
||||
outfeed = token[] outfeed(p_body.0, token)
|
||||
token0 = token[] after-all()
|
||||
outfeed = token[] outfeed(p_body.0, token0)
|
||||
ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1)
|
||||
}
|
||||
|
||||
@ -305,7 +305,7 @@ condition {
|
||||
|
||||
ENTRY entry {
|
||||
const_0 = f32[] constant(0)
|
||||
const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10))
|
||||
const_1 = (f32[], f32[]) constant((1, 10))
|
||||
while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1)
|
||||
ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
@ -554,8 +555,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) {
|
||||
|
||||
HloInstruction* new_while = FindFirstWhile(m.get());
|
||||
Shape flat_tuple =
|
||||
ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])")
|
||||
.ValueOrDie();
|
||||
ParseShape("(s32[1], s32[2], s32[3], s32[4])").ValueOrDie();
|
||||
SCOPED_TRACE(m->ToString());
|
||||
EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
@ -567,8 +567,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) {
|
||||
flat_tuple));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
m->entry_computation()->root_instruction()->shape(),
|
||||
ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))")
|
||||
.ValueOrDie()));
|
||||
ParseShape("((s32[1]), (s32[2], s32[3], (s32[4])))").ValueOrDie()));
|
||||
}
|
||||
|
||||
// Edge-case: All elements of the loop carry are constants which can be removed,
|
||||
@ -641,8 +640,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) {
|
||||
EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok());
|
||||
|
||||
HloInstruction* new_while = FindFirstWhile(m.get());
|
||||
Shape new_while_shape =
|
||||
ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie();
|
||||
Shape new_while_shape = ParseShape("(s32[1], s32[3])").ValueOrDie();
|
||||
EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
new_while->while_body()->root_instruction()->shape(), new_while_shape));
|
||||
@ -652,9 +650,9 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) {
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
new_while->while_condition()->parameter_instruction(0)->shape(),
|
||||
new_while_shape));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
m->entry_computation()->root_instruction()->shape(),
|
||||
ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie()));
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::Equal(m->entry_computation()->root_instruction()->shape(),
|
||||
ParseShape("(s32[1], s32[2], s32[3])").ValueOrDie()));
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
op::Tuple(_, op::Constant(), _));
|
||||
}
|
||||
@ -712,7 +710,7 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) {
|
||||
// We should have added a new loop counter for s32[] to the end of the tuple.
|
||||
SCOPED_TRACE(m->ToString());
|
||||
Shape new_while_shape =
|
||||
ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie();
|
||||
ParseShape("(s32[], s32[], s32[], s32[])").ValueOrDie();
|
||||
EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
new_while->while_body()->root_instruction()->shape(), new_while_shape));
|
||||
|
@ -180,8 +180,8 @@ body {
|
||||
|
||||
cond {
|
||||
param.c = (s32[], s32[]) parameter(0)
|
||||
token = token[] after-all()
|
||||
infeed = (pred[], token[]) infeed(token)
|
||||
token0 = token[] after-all()
|
||||
infeed = (pred[], token[]) infeed(token0)
|
||||
ROOT condition = pred[] get-tuple-element(infeed), index=0
|
||||
}
|
||||
|
||||
|
@ -234,7 +234,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
|
||||
PrimitiveType element_type, absl::Span<const int64> dimensions) {
|
||||
CHECK(IsArrayPrimitiveType(element_type));
|
||||
CHECK(IsArrayPrimitiveType(element_type)) << element_type;
|
||||
Shape result;
|
||||
TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result));
|
||||
return result;
|
||||
@ -480,54 +480,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return IsScalar(shape) && shape.element_type() == element_type;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Class to memoize the computation of
|
||||
// absl::AsciiStrToLower(PrimitiveType_Name(p))
|
||||
// for all PrimitiveType values "p"
|
||||
class PrimitiveTypeNameGenerator {
|
||||
public:
|
||||
PrimitiveTypeNameGenerator() {
|
||||
for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
|
||||
if (PrimitiveType_IsValid(i)) {
|
||||
lowercase_name_[i] = absl::AsciiStrToLower(
|
||||
PrimitiveType_Name(static_cast<PrimitiveType>(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
const string& LowercaseName(PrimitiveType t) {
|
||||
return lowercase_name_[static_cast<int>(t)];
|
||||
}
|
||||
|
||||
private:
|
||||
string lowercase_name_[PrimitiveType_ARRAYSIZE];
|
||||
};
|
||||
|
||||
const string& LowercasePrimitiveTypeName(PrimitiveType s) {
|
||||
static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator();
|
||||
return gen->LowercaseName(s);
|
||||
}
|
||||
|
||||
StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
|
||||
static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
|
||||
static auto* map = new std::unordered_map<string, PrimitiveType>;
|
||||
for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
|
||||
if (PrimitiveType_IsValid(i)) {
|
||||
auto value = static_cast<PrimitiveType>(i);
|
||||
(*map)[LowercasePrimitiveTypeName(value)] = value;
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}();
|
||||
auto found = name_to_type->find(name);
|
||||
if (found == name_to_type->end()) {
|
||||
return InvalidArgument("Invalid element type string: \"%s\".", name);
|
||||
}
|
||||
return found->second;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
|
||||
if (IsTuple(shape)) {
|
||||
string text = "(";
|
||||
@ -539,7 +491,8 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
|
||||
text += ")";
|
||||
return text;
|
||||
}
|
||||
return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[",
|
||||
return StrCat(
|
||||
primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
|
||||
absl::StrJoin(shape.dimensions(), ","), "]");
|
||||
}
|
||||
|
||||
@ -554,7 +507,8 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
|
||||
text += ")";
|
||||
return text;
|
||||
}
|
||||
string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[");
|
||||
string result = StrCat(
|
||||
primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[");
|
||||
for (int i = 0; i < shape.dimensions().size(); i++) {
|
||||
StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i));
|
||||
}
|
||||
@ -580,116 +534,6 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
|
||||
HumanString(program_shape.result()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Parses shapes with simple recursive descent structure -- consumes from the
|
||||
// front of s and passes that view recursively as required.
|
||||
StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
|
||||
*s = absl::StripLeadingAsciiWhitespace(*s);
|
||||
|
||||
if (absl::ConsumePrefix(s, "(")) { // Tuple.
|
||||
std::vector<Shape> shapes;
|
||||
bool must_end = false;
|
||||
while (true) {
|
||||
if (absl::ConsumePrefix(s, ")")) {
|
||||
break;
|
||||
} else if (must_end) {
|
||||
return InvalidArgument("Expected end of tuple; got: \"%s\"", *s);
|
||||
}
|
||||
shapes.emplace_back();
|
||||
TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s));
|
||||
*s = absl::StripLeadingAsciiWhitespace(*s);
|
||||
must_end = !absl::ConsumePrefix(s, ",");
|
||||
}
|
||||
return ShapeUtil::MakeTupleShape(shapes);
|
||||
}
|
||||
|
||||
string element_type_string;
|
||||
string dimensions_string;
|
||||
string format_string;
|
||||
string layout_string;
|
||||
// absl::string_view is not compatible with internal RE2 StringPiece, so
|
||||
// we convert in to the RE2-consumable type and then consume the corresponding
|
||||
// amount from our string_view type.
|
||||
static LazyRE2 shape_pattern = {
|
||||
"^(\\w*\\d*)\\[([\\d,\\s]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,\\s]+)})"
|
||||
"?"};
|
||||
tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string,
&dimensions_string, &format_string, &layout_string)) {
size_t consumed = s->size() - s_consumable.size();
s->remove_prefix(consumed);
auto string_to_int64 = [&s](absl::string_view input) -> StatusOr<int64> {
int64 element;
if (!absl::SimpleAtoi(input, &element)) {
return InvalidArgument(
"Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input,
*s);
}
return element;
};

auto comma_list_to_int64s =
[string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
std::vector<int64> results;
for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) {
TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece));
results.push_back(element);
}
return results;
};

// Extract the dimensions.
TF_ASSIGN_OR_RETURN(std::vector<int64> dimensions,
comma_list_to_int64s(dimensions_string));

// Extract the primitive element type.
TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type,
StringToPrimitiveType(element_type_string));
if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) {
return InvalidArgument("Invalid element type string: \"%s\".",
element_type_string);
}

Shape result;
if (primitive_type == OPAQUE) {
result = ShapeUtil::MakeOpaqueShape();
} else if (primitive_type == TOKEN) {
result = ShapeUtil::MakeTokenShape();
} else if (format_string.empty() && layout_string.empty()) {
// Create a shape without a layout set.
TF_ASSIGN_OR_RETURN(
result, ShapeUtil::MakeValidatedShape(primitive_type, dimensions));
} else if (format_string == "sparse") {
TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string));
result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions,
max_elements);
} else if (format_string.empty() || format_string == "dense") {
// Extract the layout minor-to-major and set it.
TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
comma_list_to_int64s(layout_string));
TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal(
primitive_type, dimensions, min2maj));
} else {
// This should not be reached.
LOG(FATAL) << "Unhandled condition when parsing shape; format: \""
<< format_string << "\", layout: \"" << layout_string << "\"";
}
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result));
return std::move(result);
}

return InvalidArgument("Invalid shape string to parse: \"%s\"", *s);
}
} // namespace

/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(absl::string_view s) {
TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s));
if (!s.empty()) {
return InvalidArgument("Invalid shape string to parse: \"%s\"", s);
}
return shape;
}
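With ShapeUtil::ParseShapeString removed, callers move to the ParseShape entry point in the HLO parser, as the text_literal_reader.cc and replay_computation.cc hunks near the end of this diff show. A minimal sketch of the replacement call (the helper name is invented here, and the absl::string_view signature is an assumption):

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"

// Hypothetical helper: parse a shape with the parser-based ParseShape and
// crash on malformed input.
xla::Shape ParseShapeOrDie(absl::string_view text) {
  return xla::ParseShape(text).ValueOrDie();
}

For example, ParseShapeOrDie("f32[123,456]{0,1}") should be ShapeUtil::Equal to ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}), the same pairing the removed ParseShapeStringWithLayout test below exercised.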

/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
const Shape& rhs) {
CHECK(ShapeUtil::IsArray(lhs));
@ -867,13 +711,13 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
if (shape.dimensions_size() != 0) {
return InvalidArgument(
"shape has %s element type, but has dimensions field: %s",
LowercasePrimitiveTypeName(shape.element_type()),
primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
shape.ShortDebugString());
}
if (shape.has_layout()) {
return InvalidArgument(
"shape has %s element type, but has layout field: %s",
LowercasePrimitiveTypeName(shape.element_type()),
primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
shape.ShortDebugString());
}
return Status::OK();

@ -241,10 +241,6 @@ class ShapeUtil {
// (param_name: f32[42x12], ...) -> f32[24x42]
static string HumanString(const ProgramShape& program_shape);

// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.
static StatusOr<Shape> ParseShapeString(absl::string_view s);

// Returns whether the LHS and RHS shapes have the same dimensions; note: does
// not check element type.
// Precondition: IsArray(lhs) && IsArray(rhs)

@ -82,102 +82,6 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) {
ASSERT_EQ(3, shape.dimensions(0));
}

TEST(ShapeUtilTest, ParseShapeStringR2F32) {
string shape_string = "f32[123,456]";
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
string shape_string = "(f32[1572864],s8[5120,1024])";
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string));
Shape expected =
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
ShapeUtil::MakeShape(S8, {5120, 1024})});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
ShapeUtil::MakeOpaqueShape(),
ShapeUtil::MakeShape(F32, {3}),
});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseShapeStringWithLayout) {
string shape_string = "f32[123,456]{0,1}";
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) {
string shape_string = "f32[123,456]dense{0,1}";
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) {
string shape_string = "f32[123,456]sparse{10}";
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10);
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseOpaqueType) {
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString("opaque[]"));
Shape expected = ShapeUtil::MakeOpaqueShape();
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseTokenType) {
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]"));
Shape expected = ShapeUtil::MakeTokenShape();
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}

TEST(ShapeUtilTest, ParseInvalidShapeString) {
string shape_strings[] = {
"f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
"f32[123,456]dense{foo}", "f32[123,456]sparse{foo}",
};
for (const string& shape_string : shape_strings) {
StatusOr<Shape> result = ShapeUtil::ParseShapeString(shape_string);
ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
}
}

TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2});
Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});

@ -89,11 +89,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
Literal literal =
Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
if (result.find("expected") != string::npos) {
EXPECT_EQ("2", literal.ToString());
EXPECT_EQ("f32[] 2", literal.ToString());
} else if (result.find("actual") != string::npos) {
EXPECT_EQ("4", literal.ToString());
EXPECT_EQ("f32[] 4", literal.ToString());
} else if (result.find("mismatches") != string::npos) {
EXPECT_EQ("true", literal.ToString());
EXPECT_EQ("pred[] true", literal.ToString());
} else {
FAIL() << "unknown file in temporary directory: " << result;
}
@ -105,9 +105,9 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual);
EXPECT_THAT(result.message(),
::testing::HasSubstr("Expected literal:\n{1, 2, 3}"));
::testing::HasSubstr("Expected literal:\ns32[3] {1, 2, 3}"));
EXPECT_THAT(result.message(),
::testing::HasSubstr("Actual literal:\n{4, 5, 6}"));
::testing::HasSubstr("Actual literal:\ns32[3] {4, 5, 6}"));
}
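A hedged aside on these updated expectations: Literal::ToString now includes the shape before the values. A minimal sketch of what the new strings encode (the free function and the exact includes here are illustrative, using the same LiteralUtil helper the test uses):

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"

// Sketch only: ToString prefixes the literal's values with its shape.
void LiteralToStringExample() {
  xla::Literal r1 = xla::LiteralUtil::CreateR1<xla::int32>({1, 2, 3});
  CHECK_EQ(r1.ToString(), "s32[3] {1, 2, 3}");
}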

TEST(LiteralTestUtilTest, NearComparatorR1) {

@ -61,11 +61,11 @@ XLA_TEST_F(TestUtilsTest, Token) {
R"(HloModule outfeed_module

ENTRY InfeedToOutfeed {
token = token[] parameter(0)
infeed = ((u32[3]{0}, pred[]), token[]) infeed(token)
token0 = token[] parameter(0)
infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
outfeed = token[] outfeed(infeed.data, token)
ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token)
outfeed = token[] outfeed(infeed.data, token0)
ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
infeed.1.token = token[] get-tuple-element(infeed.1), index=1
outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)

@ -214,8 +214,8 @@ ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] {

%forty_two = f32[] constant(42.0)
%add = f32[] add(f32[] %p0, f32[] %forty_two)
%token = token[] after-all(f32[] %add)
%p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token)
%token0 = token[] after-all(f32[] %add)
%p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token0)
%neg = f32[] negate(f32[] %p1_after_token)
ROOT %product = f32[] multiply(f32[] %add, f32[] %neg)
}
@ -236,8 +236,8 @@ HloModule AddDependencyOfConstant, is_scheduled=true
ENTRY %AddDependency (p0: f32[]) -> f32[] {
%p0 = f32[] parameter(0)
%forty_two = f32[] constant(42.0)
%token = token[] after-all(f32[] %p0)
%forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token)
%token0 = token[] after-all(f32[] %p0)
%forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token0)
ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token)
}
)";
@ -255,8 +255,8 @@ HloModule AddDependencyAsRoot, is_scheduled=true
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
%p = f32[3] parameter(0)
%neg = f32[3] negate(f32[3] %p)
%token = token[] after-all()
ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token)
%token0 = token[] after-all()
ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(
@ -274,9 +274,9 @@ ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] {
%p0 = f32[3] parameter(0)
%p1 = f32[3] parameter(1)
%forty_two = f32[] constant(42.0)
%token = token[] after-all()
%tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two)
%add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token)
%token0 = token[] after-all()
%tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token0, f32[3] %p1, f32[] %forty_two)
%add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token0)
%elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0
%elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2
ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2)

@ -555,8 +555,8 @@ XLA_TEST_F(TupleHloTest,
s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1)
gte = f32[2] get-tuple-element(s), index=0
tuple = (f32[2]) tuple(gte)
token = token[] after-all()
ROOT outfeed = token[] outfeed(tuple, token)
token0 = token[] after-all()
ROOT outfeed = token[] outfeed(tuple, token0)
}
)";
auto module =

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@ -66,7 +67,7 @@ StatusOr<Literal> TextLiteralReader::ReadAllLines() {
}

absl::StripAsciiWhitespace(&shape_string);
TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string));
TF_ASSIGN_OR_RETURN(Shape shape, ParseShape(shape_string));
if (shape.element_type() != F32) {
return Unimplemented(
"unsupported element type for text literal reading: %s",

@ -145,8 +145,7 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
bool provide_infeed = false;
Shape infeed_shape;
if (!opts.fake_infeed_shape.empty()) {
StatusOr<Shape> shape_status =
ShapeUtil::ParseShapeString(opts.fake_infeed_shape);
StatusOr<Shape> shape_status = ParseShape(opts.fake_infeed_shape);
TF_CHECK_OK(shape_status.status());
infeed_shape = std::move(shape_status).ValueOrDie();
provide_infeed = true;