fixes #37480 PiperOrigin-RevId: 302612383 Change-Id: If6cdb1149b08909d878b7808db26e64179a1a7e2
270 lines
9.9 KiB
C++
270 lines
9.9 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/aot/codegen.h"
|
|
|
|
#include <algorithm>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "absl/strings/match.h"
|
|
#include "absl/strings/string_view.h"
|
|
#include "llvm/Support/TargetSelect.h"
|
|
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
|
#include "tensorflow/compiler/xla/shape_util.h"
|
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
|
#include "tensorflow/core/framework/types.pb.h"
|
|
#include "tensorflow/core/lib/core/status.h"
|
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
#include "tensorflow/core/lib/io/path.h"
|
|
#include "tensorflow/core/platform/env.h"
|
|
#include "tensorflow/core/platform/resource_loader.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
|
|
namespace tensorflow {
|
|
namespace tfcompile {
|
|
namespace {
|
|
|
|
using ::xla::cpu_function_runtime::BufferInfo;
|
|
|
|
void ExpectErrorContains(const Status& status, absl::string_view str) {
|
|
EXPECT_NE(Status::OK(), status);
|
|
EXPECT_TRUE(absl::StrContains(status.error_message(), str))
|
|
<< "expected error: " << status.error_message() << " to contain: " << str;
|
|
}
|
|
|
|
TEST(ValidateCppIdent, Simple) {
|
|
TF_EXPECT_OK(ValidateCppIdent("a", ""));
|
|
TF_EXPECT_OK(ValidateCppIdent("abc", ""));
|
|
TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
|
|
TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
|
|
// Make sure we didn't skip a valid letter or digit
|
|
string ident;
|
|
for (char c = 'a'; c <= 'z'; c++) {
|
|
ident.append(1, c);
|
|
}
|
|
for (char c = 'A'; c <= 'Z'; c++) {
|
|
ident.append(1, c);
|
|
}
|
|
for (char c = '0'; c <= '9'; c++) {
|
|
ident.append(1, c);
|
|
}
|
|
ident += "_";
|
|
TF_EXPECT_OK(ValidateCppIdent(ident, ""));
|
|
|
|
ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
|
|
ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
|
|
ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
|
|
ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
|
|
ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
|
|
ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
|
|
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
|
|
ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
|
|
}
|
|
|
|
class ParseCppClassTest : public ::testing::Test {
|
|
protected:
|
|
void ExpectOK(const string& cpp_class, const string& want_class_name,
|
|
const std::vector<string>& want_namespaces) {
|
|
string class_name;
|
|
std::vector<string> namespaces;
|
|
TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces));
|
|
EXPECT_EQ(class_name, want_class_name);
|
|
EXPECT_EQ(namespaces, want_namespaces);
|
|
}
|
|
|
|
void ExpectFail(const string& cpp_class) {
|
|
string class_name;
|
|
std::vector<string> namespaces;
|
|
EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), Status::OK())
|
|
<< cpp_class;
|
|
}
|
|
};
|
|
|
|
TEST_F(ParseCppClassTest, ParseOK) {
|
|
ExpectOK("MyClass", "MyClass", {});
|
|
ExpectOK("_MyClass", "_MyClass", {});
|
|
ExpectOK("a::MyClass", "MyClass", {"a"});
|
|
ExpectOK("a::foo::MyClass", "MyClass", {"a", "foo"});
|
|
ExpectOK("a::foo::b::MyClass", "MyClass", {"a", "foo", "b"});
|
|
ExpectOK("a::foo::b::bar::MyClass", "MyClass", {"a", "foo", "b", "bar"});
|
|
ExpectOK("foo::MyClass", "MyClass", {"foo"});
|
|
ExpectOK("_foo::MyClass", "MyClass", {"_foo"});
|
|
ExpectOK("_foo::_MyClass", "_MyClass", {"_foo"});
|
|
ExpectOK("::foo::bar::MyClass", "MyClass", {"foo", "bar"});
|
|
ExpectOK("::_foo::MyClass", "MyClass", {"_foo"});
|
|
ExpectOK("::_foo::_MyClass", "_MyClass", {"_foo"});
|
|
// Make sure we didn't skip a valid letter or digit
|
|
string ident;
|
|
for (char c = 'a'; c <= 'z'; c++) {
|
|
ident.append(1, c);
|
|
}
|
|
for (char c = 'A'; c <= 'Z'; c++) {
|
|
ident.append(1, c);
|
|
}
|
|
for (char c = '0'; c <= '9'; c++) {
|
|
ident.append(1, c);
|
|
}
|
|
ident += "_";
|
|
ExpectOK(ident, ident, {});
|
|
ExpectOK(ident + "::" + ident, ident, {ident});
|
|
ExpectOK(ident + "::" + ident + "::" + ident, ident, {ident, ident});
|
|
}
|
|
|
|
TEST_F(ParseCppClassTest, ParseFail) {
|
|
ExpectFail("");
|
|
ExpectFail("::");
|
|
ExpectFail("0");
|
|
ExpectFail("a.b");
|
|
ExpectFail("a:b");
|
|
ExpectFail(":foo::bar");
|
|
ExpectFail("good::.bad");
|
|
ExpectFail("good:::bad");
|
|
ExpectFail("good::bad::");
|
|
ExpectFail("good::::bad");
|
|
ExpectFail("::::bad");
|
|
ExpectFail("good:: bad");
|
|
ExpectFail("good::0bad");
|
|
}
|
|
|
|
static void CompareWithGoldenFile(
|
|
const string& tensorflow_relative_golden_file_name,
|
|
const string& expected_contents, bool ignore_cr) {
|
|
// Get rid of all CR characters, we may be running under windows.
|
|
string sanitized_expected_contents(expected_contents);
|
|
if (ignore_cr) {
|
|
sanitized_expected_contents.erase(
|
|
std::remove(sanitized_expected_contents.begin(),
|
|
sanitized_expected_contents.end(), '\r'),
|
|
sanitized_expected_contents.end());
|
|
}
|
|
|
|
// To update the golden file, flip update_golden to true and run the
|
|
// following:
|
|
// bazel test --test_strategy=local \
|
|
// "third_party/tensorflow/compiler/aot:codegen_test"
|
|
const bool update_golden = false;
|
|
string golden_file_name =
|
|
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
|
|
|
|
if (update_golden) {
|
|
TF_EXPECT_OK(
|
|
WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
|
|
}
|
|
|
|
string golden_file_contents;
|
|
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
|
|
&golden_file_contents));
|
|
if (ignore_cr) {
|
|
golden_file_contents.erase(std::remove(golden_file_contents.begin(),
|
|
golden_file_contents.end(), '\r'),
|
|
golden_file_contents.end());
|
|
}
|
|
EXPECT_EQ(golden_file_contents, expected_contents);
|
|
}
|
|
|
|
TEST(CodegenTest, Golden) {
|
|
// Normally CpuCompiler::CpuCompiler does this, but in this test we've
|
|
// bypassed the Cpu compiler so we have to do this manually.
|
|
LLVMInitializeX86Target();
|
|
LLVMInitializeX86TargetInfo();
|
|
LLVMInitializeX86TargetMC();
|
|
LLVMInitializeX86AsmPrinter();
|
|
|
|
CodegenOpts opts;
|
|
opts.class_name = "MyClass";
|
|
opts.target_triple = "x86_64-pc-linux";
|
|
opts.namespaces = {"foo", "bar"};
|
|
opts.gen_name_to_index = true;
|
|
opts.gen_program_shape = true;
|
|
tf2xla::Config config;
|
|
tf2xla::Feed* feed = config.add_feed();
|
|
feed->mutable_id()->set_node_name("feed0");
|
|
feed->set_name("myfeed");
|
|
feed = config.add_feed();
|
|
feed->mutable_id()->set_node_name("feed1");
|
|
tf2xla::Fetch* fetch = config.add_fetch();
|
|
fetch->mutable_id()->set_node_name("fetch0");
|
|
fetch->set_name("myfetch");
|
|
tf2xla::Variable* variable = config.add_variable();
|
|
variable->set_node_name("myvar_readonly");
|
|
variable->mutable_shape()->add_dim()->set_size(1);
|
|
variable->set_type(DT_FLOAT);
|
|
variable->set_readonly(true);
|
|
tf2xla::Variable* variable2 = config.add_variable();
|
|
variable2->set_node_name("myvar");
|
|
variable2->mutable_shape()->add_dim()->set_size(1);
|
|
variable2->set_type(DT_FLOAT);
|
|
tf2xla::Variable* variable3 = config.add_variable();
|
|
variable3->set_node_name("my/var");
|
|
variable3->set_name("myvar2");
|
|
variable3->mutable_shape()->add_dim()->set_size(5);
|
|
variable3->set_type(DT_INT32);
|
|
CompileResult compile_result;
|
|
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
|
|
{},
|
|
{BufferInfo::MakeTempBuffer(1),
|
|
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
|
|
BufferInfo::MakeTempBuffer(1),
|
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
|
BufferInfo::MakeTempBuffer(1),
|
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
|
|
BufferInfo::MakeTempBuffer(1),
|
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
|
|
BufferInfo::MakeTempBuffer(1),
|
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
|
|
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
|
|
11, {}));
|
|
compile_result.program_shape =
|
|
xla::ShapeUtil::MakeProgramShape(
|
|
{
|
|
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
|
|
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
|
|
xla::ShapeUtil::MakeShape(xla::F32, {1}),
|
|
xla::ShapeUtil::MakeShape(xla::F32, {1}),
|
|
xla::ShapeUtil::MakeShape(xla::S32, {5}),
|
|
},
|
|
xla::ShapeUtil::MakeTupleShape({
|
|
xla::ShapeUtil::MakeShape(xla::U32, {5, 6}),
|
|
xla::ShapeUtil::MakeShape(xla::F32, {1}),
|
|
xla::ShapeUtil::MakeShape(xla::S32, {5}),
|
|
}))
|
|
.ToProto();
|
|
compile_result.entry_point = "entry_point";
|
|
compile_result.pointer_size = 8;
|
|
|
|
MetadataResult metadata_result;
|
|
TF_ASSERT_OK(GenerateMetadata(opts, compile_result, &metadata_result));
|
|
|
|
// The other fields in metadata_result are tested as part of the generated
|
|
// header test.
|
|
|
|
// This specific golden test checks a binary file. It can potentially run into
|
|
// issues due to ABIs not being stable, but has not so far.
|
|
// If we see any ABI issues, we should reconsider this specific test case.
|
|
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
|
|
metadata_result.object_file_data, false);
|
|
|
|
string header;
|
|
TF_ASSERT_OK(
|
|
GenerateHeader(opts, config, compile_result, metadata_result, &header));
|
|
|
|
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
|
|
true);
|
|
}
|
|
} // namespace
|
|
} // namespace tfcompile
|
|
} // namespace tensorflow
|