Merge pull request #14804 from yifeif/branch_176676125

Branch 176676125
Yifei Feng 2017-11-22 14:49:05 -08:00 committed by GitHub
commit 79422ab39b
459 changed files with 18645 additions and 7898 deletions

View File

@ -645,6 +645,7 @@ filegroup(
"//tensorflow/tools/test:all_files",
"//tensorflow/user_ops:all_files",
"//third_party/hadoop:all_files",
"//third_party/mpi:all_files",
"//third_party/sycl:all_files",
"//third_party/sycl/sycl:all_files",
],

View File

@ -939,13 +939,17 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
return;
}
std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
dim_vec.reserve(num_dims);
for (int i = 0; i < num_dims; ++i) {
dim_vec.push_back(ic->MakeDim(dims[i]));
tensorflow::shape_inference::ShapeHandle new_shape;
if (num_dims != -1) {
std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
dim_vec.reserve(num_dims);
for (int i = 0; i < num_dims; ++i) {
dim_vec.push_back(ic->MakeDim(dims[i]));
}
new_shape = ic->MakeShape(dim_vec);
} else {
new_shape = ic->UnknownShape();
}
tensorflow::shape_inference::ShapeHandle new_shape = ic->MakeShape(dim_vec);
status->status = graph->refiner.SetShape(node, output.index, new_shape);
}
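A minimal caller-side sketch of the path this hunk enables, assuming a graph, a TF_Output, and a TF_Status* already exist as in the TEST(CAPI, SetShape) case below; passing a null dims pointer with num_dims == -1 now explicitly produces an unknown shape, and (as the test verifies) the shape refiner will not relax a shape that is already more precise.

// Sketch only, not part of this commit.
#include <cstdio>
#include "tensorflow/c/c_api.h"

void SetUnknownShape(TF_Graph* graph, TF_Output out, TF_Status* s) {
  // dims == nullptr with num_dims == -1 marks the output shape as unknown.
  TF_GraphSetTensorShape(graph, out, /*dims=*/nullptr, /*num_dims=*/-1, s);
  if (TF_GetCode(s) != TF_OK) {
    std::fprintf(stderr, "SetShape failed: %s\n", TF_Message(s));
  }
}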

View File

@ -287,6 +287,13 @@ TEST(CAPI, SetShape) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(-1, num_dims);
// Set the shape to be unknown, expect no change.
TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(-1, num_dims);
// Set the shape to be 2 x Unknown
int64_t dims[] = {2, -1};
TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
@ -315,7 +322,17 @@ TEST(CAPI, SetShape) {
EXPECT_EQ(dims[0], returned_dims[0]);
EXPECT_EQ(dims[1], returned_dims[1]);
// Try to set 'unknown' on the shape and see that
// Try to set 'unknown' with unknown rank on the shape and see that
// it doesn't change.
TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, returned_dims[0]);
EXPECT_EQ(3, returned_dims[1]);
// Try to set 'unknown' with same rank on the shape and see that
// it doesn't change.
dims[0] = -1;
dims[1] = -1;

View File

@ -571,6 +571,12 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
status->status = ctx->func_lib_def.AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
tensorflow::mutex_lock l(ctx->functions_mu);
status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
}
} // extern "C"
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {

View File

@ -200,6 +200,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
const char* serialized_function_def,
size_t size, TF_Status* status);
// Adds a function (created from TF_GraphToFunction or
// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with
// TFE_Execute by creating an op with the same name as the function.
TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx,
TF_Function* function,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
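For orientation, a hedged sketch of the intended call sequence, mirroring the TEST(CAPI, Function) case further below; all handles are assumed to be valid, and "ident" stands for whatever name was given to the function in TF_GraphToFunction.

// Sketch only, not part of this commit.
void CallRegisteredFunction(TFE_Context* ctx, TF_Function* fn,
                            TFE_TensorHandle* input, TF_Status* status) {
  // Registering the function makes it callable as an op of the same name.
  TFE_ContextAddFunction(ctx, fn, status);
  TFE_Op* call = TFE_NewOp(ctx, "ident", status);
  TFE_OpAddInput(call, input, status);
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(call, retvals, &num_retvals, status);
  TFE_DeleteOp(call);
}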

View File

@ -295,6 +295,66 @@ TEST(CAPI, Execute) {
TF_DeleteStatus(status);
}
TEST(CAPI, Function) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();
TF_OperationDescription* arg_descr =
TF_NewOperation(function_graph, "Placeholder", "arg");
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
TF_Status* status = TF_NewStatus();
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_OperationDescription* id_descr =
TF_NewOperation(function_graph, "Identity", "id");
TF_SetAttrType(id_descr, "T", TF_INT32);
TF_AddInput(id_descr, {arg, 0});
TF_Operation* id = TF_FinishOperation(id_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_Output input{arg, 0};
TF_Output output{id, 0};
TF_Function* fn =
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
&output, nullptr, nullptr, "test", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteGraph(function_graph);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextAddFunction(ctx, fn, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteFunction(fn);
TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, 1);
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteTensor(t);
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_OpAddInput(op, h, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
std::vector<TFE_TensorHandle*> result;
result.push_back(nullptr);
int num_retvals = 1;
TFE_Execute(op, result.data(), &num_retvals, status);
TFE_DeleteOp(op);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
ASSERT_EQ(num_retvals, 1);
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
TFE_DeleteTensorHandle(h);
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
TFE_DeleteContext(ctx, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(

View File

@ -485,7 +485,6 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
Status s = vspace.CallBackwardFunction(trace.backward_function,
out_gradients, &in_gradients);
if (!s.ok()) {
VLOG(1) << "Gradient function failed.";
cleanup();
return s;
}

View File

@ -421,6 +421,7 @@ tf_cc_test(
tf_gen_op_wrappers_cc(
name = "cc_ops",
api_def_srcs = ["//tensorflow/core:base_api_def"],
op_lib_names = [
"array_ops",
"audio_ops",
@ -525,6 +526,9 @@ cc_library_with_android_deps(
"//tensorflow/core:android_tensorflow_lib",
],
copts = tf_copts(),
data = [
"//tensorflow/core:base_api_def",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -536,6 +540,29 @@ cc_library_with_android_deps(
],
)
tf_cc_test(
name = "cc_op_gen_test",
srcs = [
"framework/cc_op_gen.cc",
"framework/cc_op_gen.h",
"framework/cc_op_gen_test.cc",
],
data = [
"//tensorflow/cc:ops/op_gen_overrides.pbtxt",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:op_gen_overrides_proto_cc",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "test_op_op_lib",
srcs = ["framework/test_op.cc"],

View File

@ -18,10 +18,11 @@ limitations under the License.
#include <vector>
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb_text.h"
@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
const int kRightMargin = 79;
@ -297,7 +297,7 @@ string ToCamelCase(const string& str) {
// argument to a function.
std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
static const std::unordered_map<StringPiece, std::pair<const char*, bool>,
StringPiece::Hasher>
StringPieceHasher>
attr_type_map{
{"string", {"StringPiece", false}},
{"list(string)", {"gtl::ArraySlice<string>", true}},
@ -325,29 +325,112 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
}
bool IsCPPKeyword(StringPiece name) {
static const std::unordered_set<StringPiece, StringPiece::Hasher>
static const std::unordered_set<StringPiece, StringPieceHasher>
// Keywords obtained from http://en.cppreference.com/w/cpp/keyword
kCPPReserved{
"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel",
"atomic_commit", "atomic_noexcept", "auto", "bitand", "bitor", "bool",
"break", "case", "catch", "char", "char16_t", "char32_t", "class",
"compl", "concept", "const", "const_cast", "constexpr", "continue",
"decltype", "default", "delete", "do", "double", "dynamic_cast",
"else", "enum", "explicit", "export", "extern", "false", "final",
"float", "for", "friend", "goto", "if", "import", "inline", "int",
"long", "module", "mutable", "namespace", "new", "noexcept", "not",
"not_eq", "nullptr", "operator", "or", "or_eq", "override", "private",
"protected", "public", "register", "reinterpret_cast", "requires",
"return", "short", "signed", "sizeof", "static", "static_assert",
"static_cast", "struct", "switch", "synchronized", "template", "this",
"thread_local", "throw", "true", "try", "typedef", "typeid",
"typename", "union", "unsigned", "using", "virtual", "void",
"volatile", "wchar_t", "while", "xor", "xor_eq",
"alignas",
"alignof",
"and",
"and_eq",
"asm",
"atomic_cancel",
"atomic_commit",
"atomic_noexcept",
"auto",
"bitand",
"bitor",
"bool",
"break",
"case",
"catch",
"char",
"char16_t",
"char32_t",
"class",
"compl",
"concept",
"const",
"const_cast",
"constexpr",
"continue",
"decltype",
"default",
"delete",
"do",
"double",
"dynamic_cast",
"else",
"enum",
"explicit",
"export",
"extern",
"false",
"final",
"float",
"for",
"friend",
"goto",
"if",
"import",
"inline",
"int",
"long",
"module",
"mutable",
"namespace",
"new",
"noexcept",
"not",
"not_eq",
"nullptr",
"operator",
"or",
"or_eq",
"override",
"private",
"protected",
"public",
"register",
"reinterpret_cast",
"requires",
"return",
"short",
"signed",
"sizeof",
"static",
"static_assert",
"static_cast",
"struct",
"switch",
"synchronized",
"template",
"this",
"thread_local",
"throw",
"true",
"try",
"typedef",
"typeid",
"typename",
"union",
"unsigned",
"using",
"virtual",
"void",
"volatile",
"wchar_t",
"while",
"xor",
"xor_eq",
// The following are not C++ keywords, but names of local variables
// and parameters used in the op constructor. Treating them as
// keywords, so that other parameter names don't conflict with these.
"builder", "node", "ret", "scope", "unique_name",
"builder",
"node",
"ret",
"scope",
"unique_name",
};
return kCPPReserved.count(name) > 0;
}
@ -385,10 +468,10 @@ bool ArgIsList(const OpDef::ArgDef& arg) {
}
bool HasOptionalAttrs(
const OpDef& op_def,
const ApiDef& api_def,
const std::unordered_map<string, string>& inferred_input_attrs) {
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < api_def.attr_size(); ++i) {
const auto& attr(api_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) ==
inferred_input_attrs.end()) &&
attr.has_default_value()) {
@ -398,12 +481,21 @@ bool HasOptionalAttrs(
return false;
}
const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
for (int i = 0; i < api_def.in_arg_size(); ++i) {
if (api_def.in_arg(i).name() == name) {
return &api_def.in_arg(i);
}
}
return nullptr;
}
struct OpInfo {
// graph_op_def: The OpDef used by the runtime, has the names that
// must be used when calling NodeBuilder.
// interface_op_def: The OpDef used in the interface in the generated
// code, with possibly overridden names and defaults.
explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def,
explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases);
string GetOpAttrStruct() const;
string GetConstructorDecl(StringPiece op_name_prefix,
@ -423,74 +515,81 @@ struct OpInfo {
string comment;
const OpDef& graph_op_def;
const OpDef& op_def;
const ApiDef& api_def;
const std::vector<string>& aliases;
// Map from type attribute to corresponding original argument name.
std::unordered_map<string, string> inferred_input_attrs;
};
OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
const std::vector<string>& a)
: graph_op_def(g_op_def), op_def(i_op_def), aliases(a) {
op_name = op_def.name();
InferOpAttributes(op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs);
OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases)
: graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
op_name = api_def.endpoint(0).name();
InferOpAttributes(graph_op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
arg_types.push_back("const ::tensorflow::Scope&");
arg_names.push_back("scope");
if (op_def.has_deprecation()) {
if (!op_def.summary().empty()) {
comment = strings::StrCat(op_def.summary(), "\n");
if (graph_op_def.has_deprecation()) {
if (!api_def.summary().empty()) {
comment = strings::StrCat(api_def.summary(), "\n");
}
strings::StrAppend(&comment, "DEPRECATED at GraphDef version ",
op_def.deprecation().version(), ":\n",
op_def.deprecation().explanation(), ".\n");
} else if (op_def.summary().empty()) {
graph_op_def.deprecation().version(), ":\n",
graph_op_def.deprecation().explanation(), ".\n");
} else if (api_def.summary().empty()) {
comment = "TODO: add doc.\n";
} else {
comment = strings::StrCat(op_def.summary(), "\n");
comment = strings::StrCat(api_def.summary(), "\n");
}
if (!op_def.description().empty()) {
strings::StrAppend(&comment, "\n", op_def.description(), "\n");
if (!api_def.description().empty()) {
strings::StrAppend(&comment, "\n", api_def.description(), "\n");
}
strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");
// Process inputs
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
for (int i = 0; i < api_def.arg_order_size(); ++i) {
const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def);
const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def);
arg_types.push_back(strings::StrCat(
"::tensorflow::", ArgIsList(arg) ? "InputList" : "Input"));
arg_names.push_back(AvoidCPPKeywords(arg.name()));
arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
// TODO(keveman): Include input type information.
StringPiece description = arg.description();
StringPiece description = api_def_arg.description();
if (!description.empty()) {
ConsumeEquals(&description);
strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ",
arg.description(), "\n");
strings::StrAppend(&comment, "* ",
AvoidCPPKeywords(api_def_arg.rename_to()), ": ",
api_def_arg.description(), "\n");
}
}
// Process attrs
string required_attrs_comment;
string optional_attrs_comment;
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
// ApiDef attributes must be in the same order as in OpDef since
// we initialize ApiDef based on OpDef.
const auto& attr(graph_op_def.attr(i));
const auto& api_def_attr(api_def.attr(i));
CHECK_EQ(attr.name(), api_def_attr.name());
// Skip inferred arguments
if (inferred_input_attrs.count(attr.name()) > 0) continue;
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
string attr_name = AvoidCPPKeywords(attr.name());
string attr_name = AvoidCPPKeywords(api_def_attr.rename_to());
string attr_comment;
if (!attr.description().empty()) {
if (!api_def_attr.description().empty()) {
// TODO(keveman): Word wrap and indent this, to handle multi-line
// descriptions.
strings::StrAppend(&attr_comment, "* ", attr_name, ": ",
attr.description(), "\n");
api_def_attr.description(), "\n");
}
if (attr.has_default_value()) {
if (api_def_attr.has_default_value()) {
strings::StrAppend(&optional_attrs_comment, attr_comment);
} else {
strings::StrAppend(&required_attrs_comment, attr_comment);
@ -508,44 +607,49 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
}
// Process outputs
for (int i = 0; i < op_def.output_arg_size(); ++i) {
const auto& arg = op_def.output_arg(i);
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
// ApiDef arguments must be in the same order as in OpDef since
// we initialize ApiDef based on OpDef.
const auto& arg = graph_op_def.output_arg(i);
const auto& api_def_arg(api_def.out_arg(i));
CHECK_EQ(arg.name(), api_def_arg.name());
bool is_list = ArgIsList(arg);
output_types.push_back(
strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output"));
output_names.push_back(AvoidCPPKeywords(arg.name()));
output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
is_list_output.push_back(is_list);
}
strings::StrAppend(&comment, "\nReturns:\n");
if (op_def.output_arg_size() == 0) { // No outputs.
if (graph_op_def.output_arg_size() == 0) { // No outputs.
strings::StrAppend(&comment, "* the created `Operation`\n");
} else if (op_def.output_arg_size() == 1) { // One output
} else if (graph_op_def.output_arg_size() == 1) { // One output
if (is_list_output[0]) {
strings::StrAppend(&comment, "* `OutputList`: ");
} else {
strings::StrAppend(&comment, "* `Output`: ");
}
if (op_def.output_arg(0).description().empty()) {
strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(),
if (api_def.out_arg(0).description().empty()) {
strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(),
" tensor.\n");
} else {
// TODO(josh11b): Word wrap this.
strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n");
strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n");
}
} else { // Multiple outputs.
for (int i = 0; i < op_def.output_arg_size(); ++i) {
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
if (is_list_output[i]) {
strings::StrAppend(&comment, "* `OutputList`");
} else {
strings::StrAppend(&comment, "* `Output`");
}
strings::StrAppend(&comment, " ", output_names[i]);
if (op_def.output_arg(i).description().empty()) {
if (api_def.out_arg(i).description().empty()) {
strings::StrAppend(&comment, "\n");
} else {
// TODO(josh11b): Word wrap this.
strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(),
strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(),
"\n");
}
}
@ -564,19 +668,20 @@ string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
const auto& api_def_attr(api_def.attr(i));
// If attr will be inferred or it doesn't have a default value, don't
// add it to the struct.
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
!attr.has_default_value()) {
!api_def_attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
const string camel_case_name = ToCamelCase(attr.name());
const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def =
@ -584,22 +689,25 @@ string OpInfo::GetOpAttrStruct() const {
attr_type_name, use_const ? "&" : "");
string attr_comment;
if (!attr.description().empty()) {
strings::StrAppend(&attr_comment, attr.description(), "\n\n");
if (!api_def_attr.description().empty()) {
strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n");
}
strings::StrAppend(&attr_comment, "Defaults to ",
SummarizeAttrValue(attr.default_value()), "\n");
SummarizeAttrValue(api_def_attr.default_value()), "\n");
attr_comment = MakeComment(attr_comment, " ");
strings::StrAppend(&setters, attr_comment);
strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n");
strings::StrAppend(&setters, " Attrs ret = *this;\n");
strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n");
strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(),
"_ = x;\n");
strings::StrAppend(&setters, " return ret;\n }\n\n");
strings::StrAppend(
&struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ",
PrintAttrValue(op_def.name(), attr.default_value()), ";\n");
&struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(),
"_ = ",
PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
";\n");
}
if (struct_fields.empty()) {
@ -676,17 +784,18 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
// Add the static functions to set optional attrs
if (has_optional_attrs) {
strings::StrAppend(&class_decl, "\n");
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
const auto& api_def_attr(api_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
!attr.has_default_value()) {
!api_def_attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
const string camel_case_name = ToCamelCase(attr.name());
const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def = strings::StrCat(
@ -726,11 +835,11 @@ void OpInfo::GetOutput(string* out) const {
strings::StrCat("if (!", scope_str, ".ok()) return;");
// No outputs.
if (op_def.output_arg_size() == 0) {
if (graph_op_def.output_arg_size() == 0) {
strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
return;
}
if (op_def.output_arg_size() == 1) {
if (graph_op_def.output_arg_size() == 1) {
// One output, no need for NameRangeMap
if (is_list_output[0]) {
strings::StrAppend(out,
@ -752,7 +861,7 @@ void OpInfo::GetOutput(string* out) const {
".UpdateStatus(_status_);\n", " return;\n");
strings::StrAppend(out, " }\n\n");
for (int i = 0; i < op_def.output_arg_size(); ++i) {
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
const string arg_range = strings::StrCat(
"_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
if (is_list_output[i]) {
@ -776,11 +885,13 @@ string OpInfo::GetConstructorBody() const {
strings::StrAppend(&body, " ", return_on_error, "\n");
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::",
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(",
scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n");
for (int i = 0; i < graph_op_def.input_arg_size(); ++i) {
const auto& arg(graph_op_def.input_arg(i));
const auto& api_def_arg(api_def.in_arg(i));
strings::StrAppend(
&body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::",
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ",
AvoidCPPKeywords(api_def_arg.rename_to()), ");\n");
strings::StrAppend(&body, " ", return_on_error, "\n");
}
@ -791,19 +902,21 @@ string OpInfo::GetConstructorBody() const {
&body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
graph_op_def.name(), "\")\n");
const string spaces = " ";
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n");
for (int i = 0; i < api_def.in_arg_size(); ++i) {
const auto& arg(api_def.in_arg(i));
strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n");
}
for (int i = 0; i < op_def.attr_size(); ++i) {
for (int i = 0; i < api_def.attr_size(); ++i) {
const auto& graph_attr(graph_op_def.attr(i));
const auto& attr(op_def.attr(i));
if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) {
const auto& api_def_attr(api_def.attr(i));
if (inferred_input_attrs.find(api_def_attr.name()) !=
inferred_input_attrs.end()) {
continue;
}
const string attr_name = attr.has_default_value()
? strings::StrCat("attrs.", attr.name(), "_")
: AvoidCPPKeywords(attr.name());
const string attr_name =
api_def_attr.has_default_value()
? strings::StrCat("attrs.", api_def_attr.rename_to(), "_")
: AvoidCPPKeywords(api_def_attr.rename_to());
strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
attr_name, ")\n");
}
@ -845,10 +958,10 @@ void OpInfo::WriteClassDef(WritableFile* cc) const {
TF_CHECK_OK(cc->Append(class_def));
}
void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def,
void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases, WritableFile* h,
WritableFile* cc) {
OpInfo op_info(graph_op_def, interface_op_def, aliases);
OpInfo op_info(graph_op_def, api_def, aliases);
op_info.WriteClassDecl(h);
op_info.WriteClassDef(cc);
@ -943,8 +1056,9 @@ string MakeInternal(const string& fname) {
} // namespace
void WriteCCOps(const OpList& ops, const string& dot_h_fname,
const string& dot_cc_fname, const string& overrides_fnames) {
void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& dot_h_fname, const string& dot_cc_fname,
const string& overrides_fnames) {
Env* env = Env::Default();
// Load the override map.
@ -984,24 +1098,23 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname,
// code depends on it.
if (graph_op_def.name() == "Const") continue;
// Incorporate overrides from override_map.
OpDef interface_op_def = graph_op_def;
const OpGenOverride* op_override =
override_map.ApplyOverride(&interface_op_def);
std::vector<string> aliases;
if (op_override) {
if (op_override->skip()) continue;
aliases.assign(op_override->alias().begin(), op_override->alias().end());
if (op_override->hide()) {
// Write hidden ops to _internal.h and _internal.cc.
WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(),
internal_cc.get());
continue;
}
}
const auto* api_def = api_def_map.GetApiDef(graph_op_def.name());
std::vector<string> aliases;
if (api_def->visibility() == ApiDef::SKIP) continue;
// First endpoint is canonical, the rest are aliases.
for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size();
++endpoint_i) {
aliases.push_back(api_def->endpoint(endpoint_i).name());
}
if (api_def->visibility() == ApiDef::HIDDEN) {
// Write hidden ops to _internal.h and _internal.cc.
WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(),
internal_cc.get());
continue;
}
// This isn't a hidden op, write it to the main files.
WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get());
WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get());
}
FinishFiles(false, h.get(), cc.get(), op_header_guard);

View File

@ -17,13 +17,15 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
/// Result is written to files dot_h and dot_cc.
void WriteCCOps(const OpList& ops, const string& dot_h_fname,
const string& dot_cc_fname, const string& overrides_fnames);
void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& dot_h_fname, const string& dot_cc_fname,
const string& overrides_fnames);
} // namespace tensorflow

View File

@ -16,7 +16,11 @@ limitations under the License.
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/types.h"
@ -24,10 +28,28 @@ namespace tensorflow {
namespace {
void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
const std::string& overrides_fnames, bool include_internal) {
const std::string& overrides_fnames, bool include_internal,
const std::vector<string>& api_def_dirs) {
OpList ops;
OpRegistry::Global()->Export(include_internal, &ops);
WriteCCOps(ops, dot_h, dot_cc, overrides_fnames);
ApiDefMap api_def_map(ops);
if (!api_def_dirs.empty()) {
Env* env = Env::Default();
// Only load files that correspond to "ops".
for (const auto& op : ops.op()) {
for (const auto& api_def_dir : api_def_dirs) {
const std::string api_def_file_pattern =
io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
if (env->FileExists(api_def_file_pattern).ok()) {
TF_CHECK_OK(api_def_map.LoadFile(env, api_def_file_pattern));
}
}
}
}
api_def_map.UpdateDocs();
WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames);
}
} // namespace
@ -35,18 +57,24 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc != 5) {
// TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt
// as an argument.
if (argc != 6) {
for (int i = 1; i < argc; ++i) {
fprintf(stderr, "Arg %d = %s\n", i, argv[i]);
}
fprintf(stderr,
"Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n"
"Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal "
"api_def_dirs1,api_def_dir2 ...\n"
" include_internal: 1 means include internal ops\n",
argv[0]);
exit(1);
}
bool include_internal = tensorflow::StringPiece("1") == argv[4];
tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal);
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
argv[5], ",", tensorflow::str_util::SkipEmpty());
tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal,
api_def_dirs);
return 0;
}

View File

@ -0,0 +1,195 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// TODO(annarev): Remove this op_gen_overrides.pbtxt reference.
// It is needed only because WriteCCOps takes it as an argument.
constexpr char kOverridesFnames[] =
"tensorflow/cc/ops/op_gen_overrides.pbtxt";
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
input_arg {
name: "images"
description: "Images to process."
}
input_arg {
name: "dim"
description: "Description for dim."
type: DT_FLOAT
}
output_arg {
name: "output"
description: "Description for output."
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
description: "Type for images"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
}
}
default_value {
i: 1
}
}
summary: "Summary for op Foo."
description: "Description for op Foo."
}
)";
void ExpectHasSubstr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(s.contains(expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
EXPECT_FALSE(s.contains(expected))
<< "'" << s << "' contains '" << expected << "'";
}
void ExpectSubstrOrder(const string& s, const string& before,
const string& after) {
int before_pos = s.find(before);
int after_pos = s.find(after);
ASSERT_NE(std::string::npos, before_pos);
ASSERT_NE(std::string::npos, after_pos);
EXPECT_LT(before_pos, after_pos)
<< before << " is not before " << after << " in " << s;
}
// Runs WriteCCOps and stores output in (internal_)cc_file_path and
// (internal_)h_file_path.
void GenerateCcOpFiles(Env* env, const OpList& ops,
const ApiDefMap& api_def_map, string* h_file_text,
string* internal_h_file_text) {
const string& tmpdir = testing::TmpDir();
const auto h_file_path = io::JoinPath(tmpdir, "test.h");
const auto cc_file_path = io::JoinPath(tmpdir, "test.cc");
const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h");
const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc");
WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames);
TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text));
TF_ASSERT_OK(
ReadFileToString(env, internal_h_file_path, internal_h_file_text));
}
TEST(CcOpGenTest, TestVisibilityChangedToHidden) {
const string api_def = R"(
op {
graph_op_name: "Foo"
visibility: HIDDEN
}
)";
Env* env = Env::Default();
OpList op_defs;
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT
ApiDefMap api_def_map(op_defs);
string h_file_text, internal_h_file_text;
// Without ApiDef
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectHasSubstr(h_file_text, "class Foo");
ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo");
// With ApiDef
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectHasSubstr(internal_h_file_text, "class Foo");
ExpectDoesNotHaveSubstr(h_file_text, "class Foo");
}
TEST(CcOpGenTest, TestArgNameChanges) {
const string api_def = R"(
op {
graph_op_name: "Foo"
arg_order: "dim"
arg_order: "images"
}
)";
Env* env = Env::Default();
OpList op_defs;
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT
ApiDefMap api_def_map(op_defs);
string cc_file_text, h_file_text;
string internal_cc_file_text, internal_h_file_text;
// Without ApiDef
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectSubstrOrder(h_file_text, "Input images", "Input dim");
// With ApiDef
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectSubstrOrder(h_file_text, "Input dim", "Input images");
}
TEST(CcOpGenTest, TestEndpoints) {
const string api_def = R"(
op {
graph_op_name: "Foo"
endpoint {
name: "Foo1"
}
endpoint {
name: "Foo2"
}
}
)";
Env* env = Env::Default();
OpList op_defs;
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT
ApiDefMap api_def_map(op_defs);
string cc_file_text, h_file_text;
string internal_cc_file_text, internal_h_file_text;
// Without ApiDef
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectHasSubstr(h_file_text, "class Foo {");
ExpectDoesNotHaveSubstr(h_file_text, "class Foo1");
ExpectDoesNotHaveSubstr(h_file_text, "class Foo2");
// With ApiDef
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectHasSubstr(h_file_text, "class Foo1");
ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2");
ExpectDoesNotHaveSubstr(h_file_text, "class Foo {");
}
} // namespace
} // namespace tensorflow

View File

@ -21,6 +21,9 @@ namespace tensorflow {
/// Tag for the `gpu` graph.
constexpr char kSavedModelTagGpu[] = "gpu";
/// Tag for the `tpu` graph.
constexpr char kSavedModelTagTpu[] = "tpu";
/// Tag for the `serving` graph.
constexpr char kSavedModelTagServe[] = "serve";
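A hedged sketch of how these tags are typically consumed when loading a SavedModel; the export path is a hypothetical placeholder, and only the new `tpu` tag is specific to this change.

// Sketch only, not part of this commit.
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"

tensorflow::Status LoadTpuModel(tensorflow::SavedModelBundle* bundle) {
  // Load the graph that was exported under the `tpu` tag.
  return tensorflow::LoadSavedModel(
      tensorflow::SessionOptions(), tensorflow::RunOptions(),
      "/tmp/exported_model" /* hypothetical path */,
      {tensorflow::kSavedModelTagTpu}, bundle);
}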

View File

@ -24,10 +24,12 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
@test_util.with_c_api
class FunctionTest(XLATestCase):
def testFunction(self):

View File

@ -623,11 +623,12 @@ class FunctionalizeCond {
FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library)
: clusters_(graph->num_node_ids()), library_(library), graph_(graph) {}
// Returns a vector of Merge nodes from the clustered graph where the nodes
// Returns a vector of Switch nodes from the clustered graph where the nodes
// are sorted by the number of switch nodes minus number of merge nodes
// from a root of the clustered graph to the given Merge node, with ties
// broken by the representative of the Cluster.
std::vector<std::pair<int, Cluster*>> SortedMergeNodes();
// broken by the representative of the Cluster. This corresponds to sorting by
// nesting depth, from deepest nested to outermost.
std::vector<std::pair<int, Cluster*>> SortedSwitchNodes();
// Returns whether the graph has no conditionals.
bool NoConditionals() const { return merge_nodes_.empty(); }
@ -654,15 +655,17 @@ class FunctionalizeCond {
// extracting the bodies needed for the then and else branch, creates a XlaIf
// node, removing the nodes of the branches from the graph and replacing the
// merge node with a XlaIf.
Status ConvertMergeToXlaIf(Cluster* merge_cluster);
Status ConvertCorrespondingMergeToXlaIf(Cluster* switch_cluster);
// Removes a Switch cluster feeding directly into a Merge cluster by removing
// the Switch and Merge nodes and collapsing into a single cluster.
Status RemoveTrivialMerge(Cluster* merge_cluster);
Status RemoveTrivialSwitch(Cluster* switch_cluster);
// Returns the switch cluster corresponding to the merge node. This function
// only returns the switch cluster in the simple case where we have a switch
// node is the entry of a diamond corresponding to a conditional:
// Returns the merge cluster corresponding to the switch node. This function
// only returns the merge cluster in the case where we have a switch node that
// is the single entry point for all paths to a common merge cluster; this
// merge cluster may be created by combining multiple merge clusters that
// share the switch cluster as a common ancestor.
//
// Switch
// / \
@ -671,8 +674,9 @@ class FunctionalizeCond {
// merge_cluster
//
// Note: either of the branches may be empty. The case where both branches are
// empty is handled by RemoveTrivialMerge.
gtl::optional<Cluster*> GetSwitchCluster(const Cluster& merge_cluster);
// empty is handled by RemoveTrivialSwitch.
gtl::optional<Cluster*> CreateCorrespondingMergeCluster(
Cluster* switch_cluster);
// Determines the arguments needed as input to the Merge cluster originating
// from the Switch cluster.
@ -793,6 +797,10 @@ bool IsDeadSwitch(const Node* node) {
}
void FunctionalizeCond::CreateClusters() {
ClusterHandle source_cluster = ClusterHandle(Graph::kSourceId);
auto& source = clusters_.at(source_cluster);
std::deque<std::pair<ClusterHandle, std::deque<Node*>>> workqueue;
workqueue.push_back({source_cluster, {}});
for (Node* node : graph_->nodes()) {
if (IsSwitch(node)) {
switch_nodes_.insert(node);
@ -801,6 +809,12 @@ void FunctionalizeCond::CreateClusters() {
}
ClusterHandle& cluster = clusters_.at(node).Get();
cluster = ClusterHandle(node->id());
// Group all source clusters together.
if (node->IsSource() || node->in_edges().empty()) {
clusters_.at(node).Merge(&source);
source.Merge(&clusters_.at(node));
workqueue.front().second.push_back(node);
}
}
// If there are no Merge nodes, then terminate.
@ -815,20 +829,118 @@ void FunctionalizeCond::CreateClusters() {
// conservatively assuming all merge nodes become XlaIf nodes.
clusters_.resize(clusters_.size() + merge_nodes_.size());
// Merge a cluster with its input, unless the input is a Switch node or
// the node is a Merge node.
for (const Node* node : graph_->nodes()) {
if (IsMerge(node) || IsSwitch(node) || !node->IsOp()) {
continue;
}
for (const Node* in : node->in_nodes()) {
if (in->IsOp() && !IsSwitch(in) && !IsMerge(in)) {
clusters_.at(node).Merge(&clusters_.at(in));
std::unordered_set<Node*> marked;
while (!workqueue.empty()) {
auto cluster_queue = workqueue.front();
VLOG(4) << "Cluster: " << cluster_queue.first << " Queue: {"
<< str_util::Join(cluster_queue.second, ",",
[](string* output, const Node* node) {
strings::StrAppend(output, node->id());
})
<< "}";
UnionFind<ClusterHandle>& repr = clusters_.at(cluster_queue.first);
workqueue.pop_front();
std::deque<Node*> switch_nodes;
std::deque<Node*> merge_nodes;
std::unordered_set<Node*> cluster_member;
while (!cluster_queue.second.empty()) {
// Iterate over the node workqueue and flow forward, merging all reachable
// nodes that are neither a Switch nor a Merge and whose inputs are all part
// of the same cluster.
Node* cur = cluster_queue.second.front();
cluster_queue.second.pop_front();
if (marked.find(cur) != marked.end()) {
continue;
}
if (IsMerge(cur)) {
merge_nodes.push_back(cur);
marked.insert(cur);
continue;
}
if (IsSwitch(cur)) {
switch_nodes.push_back(cur);
marked.insert(cur);
continue;
}
clusters_.at(cur).Merge(&repr);
cluster_member.insert(cur);
for (Node* out : cur->out_nodes()) {
bool all_ancestors_in_cluster = true;
for (Node* in : out->in_nodes()) {
if (IsMerge(out)) {
merge_nodes.push_back(out);
}
if (IsSwitch(out)) {
switch_nodes.push_back(out);
}
if (cluster_member.find(in) == cluster_member.end()) {
all_ancestors_in_cluster = false;
break;
}
}
if (all_ancestors_in_cluster && out->IsOp()) {
cluster_queue.second.push_back(out);
marked.insert(cur);
}
}
}
// Group all source clusters together.
if (node->IsSource() || node->in_edges().empty()) {
clusters_.at(node).Merge(&clusters_.at(ClusterHandle(Graph::kSourceId)));
VLOG(4) << "Switches: {"
<< str_util::Join(switch_nodes, ",",
[](string* output, const Node* node) {
strings::StrAppend(output, node->id());
})
<< "}";
// Merge Switch nodes with common predicate.
std::unordered_map<Node*, std::vector<Node*>> predicate_to_switch;
for (Node* node : switch_nodes) {
Node* tmp;
TF_CHECK_OK(node->input_node(1, &tmp));
predicate_to_switch[tmp].push_back(node);
}
for (auto kv : predicate_to_switch) {
Node* first = kv.second.front();
for (Node* switch_node : kv.second) {
clusters_.at(first).Merge(&clusters_.at(switch_node));
}
}
// Enqueue each edge of the switch node separately. That is, group all the
// nodes that are due to the true/false edge of the switch together and
// consider all nodes that only have a control dependency on the switch node
// separately. We want to group together all nodes that are part of the same
// branch, as these will be extracted into the `then` and `else` functions
// of the functional if. The ops due to control edges are different as they
// could be involved with either branch and merging them here could result
// in invalid graphs.
for (auto kv : predicate_to_switch) {
ClusterHandle none = ClusterHandle(-1);
ClusterHandle first[2] = {none, none};
std::deque<Node*>* queue[2];
for (auto switch_node : kv.second) {
for (const auto e : switch_node->out_edges()) {
if (IsSwitch(e->dst()) || IsMerge(e->dst())) {
continue;
}
// Control edges are enqueued on their own.
if (e->IsControlEdge()) {
workqueue.push_back({Representative(e->dst()), {e->dst()}});
continue;
}
// Combine all outputs of the same output port of a switch cluster
// into the same workqueue entry.
if (first[e->src_output()] == none) {
ClusterHandle repr = Representative(e->dst());
first[e->src_output()] = repr;
workqueue.push_back({repr, {}});
queue[e->src_output()] = &workqueue.back().second;
}
clusters_.at(first[e->src_output()]).Merge(&clusters_.at(e->dst()));
queue[e->src_output()]->push_back(e->dst());
}
}
}
}
}
@ -910,74 +1022,60 @@ void FunctionalizeCond::CreateClusteredGraph() {
update_cluster_for_node(node).merge_nodes.insert(node);
}
// Merge Switch nodes with common predicate.
std::unordered_map<Node*, std::vector<Node*>> predicate_to_switch;
for (Node* node : switch_nodes_) {
Node* tmp;
TF_CHECK_OK(node->input_node(1, &tmp));
predicate_to_switch[tmp].push_back(node);
}
for (auto kv : predicate_to_switch) {
Cluster& first = clustered_graph_.at(Representative(kv.second.front()));
for (Node* switch_node : kv.second) {
ClusterHandle handle = Representative(switch_node);
Cluster& cluster = clustered_graph_.at(handle);
ContractEdge(&cluster, &first, /*remove_from_graph=*/true);
}
}
// Merge Merge nodes with common input together.
for (Node* node : merge_nodes_) {
Cluster& cluster = clustered_graph_.at(Representative(node));
for (const Node* in : node->in_nodes()) {
if (!in->IsOp()) {
continue;
}
Cluster& cluster_node_in = clustered_graph_.at(Representative(in));
// ContractEdge can modify out_nodes of cluster_node_in, so traverse
// over out_nodes assuming it does.
for (auto it = cluster_node_in.out_nodes.begin();
it != cluster_node_in.out_nodes.end();) {
if (!(*it)->merge_nodes.empty()) {
ContractEdge(*it++, &cluster, /*remove_from_graph=*/true);
} else {
++it;
}
}
}
}
VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_);
VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_);
}
gtl::optional<FunctionalizeCond::Cluster*> FunctionalizeCond::GetSwitchCluster(
const Cluster& merge_cluster) {
VLOG(3) << "GetSwitchCluster for " << merge_cluster.representative;
gtl::optional<Cluster*> switch_cluster;
if (merge_cluster.in_nodes.size() > 2) {
return gtl::nullopt;
gtl::optional<FunctionalizeCond::Cluster*>
FunctionalizeCond::CreateCorrespondingMergeCluster(Cluster* switch_cluster) {
VLOG(3) << "CreateCorrespondingMergeCluster for "
<< switch_cluster->representative;
std::unordered_set<Cluster*> merges;
std::unordered_set<Cluster*> dominated;
dominated.insert(switch_cluster);
std::deque<Cluster*> queue;
auto enqueue_or_update_merge = [this, &queue, &merges](Cluster* c) {
if (c->merge_nodes.empty()) {
queue.push_back(c);
} else {
merges.insert(c);
}
};
// Enqueue all the outputs of the switch cluster in the workqueue.
for (auto* out : switch_cluster->out_nodes) {
enqueue_or_update_merge(out);
}
for (Cluster* in : merge_cluster.in_nodes) {
Cluster* cluster = in;
if (in->switch_nodes.empty()) {
if (in->in_nodes.size() != 1 || in->out_nodes.size() != 1) {
std::unordered_set<Cluster*> visited;
while (!queue.empty()) {
Cluster* cur = queue.front();
queue.pop_front();
if (visited.find(cur) != visited.end()) {
continue;
}
visited.insert(cur);
// Ensure all inputs to the current node are in the dominated set.
for (Cluster* in : cur->in_nodes) {
if (dominated.find(in) == dominated.end()) {
return gtl::nullopt;
}
// There is only a single `in` cluster.
cluster = *in->in_nodes.begin();
}
if (cluster->switch_nodes.empty()) {
return gtl::nullopt;
}
if (switch_cluster.has_value() && *switch_cluster != cluster) {
return gtl::nullopt;
} else {
switch_cluster = cluster;
for (Cluster* out : cur->out_nodes) {
// No switch nodes beyond the entry one is expected.
if (!out->switch_nodes.empty()) {
return gtl::nullopt;
}
enqueue_or_update_merge(out);
}
}
return switch_cluster;
auto it = merges.begin();
Cluster* merge_cluster = *it;
for (++it; it != merges.end(); ++it) {
ContractEdge(*it, merge_cluster);
}
// TODO(jpienaar): Clean up graph, merging nodes.
return merge_cluster;
}
xla::StatusOr<FunctionalizeCond::CondArgs> FunctionalizeCond::DetermineCondArgs(
@ -1221,11 +1319,11 @@ void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) {
}
}
Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) {
Cluster* switch_cluster = *merge_cluster->in_nodes.begin();
if (switch_cluster->switch_nodes.empty()) {
Status FunctionalizeCond::RemoveTrivialSwitch(Cluster* switch_cluster) {
Cluster* merge_cluster = *switch_cluster->out_nodes.begin();
if (merge_cluster->merge_nodes.empty()) {
return errors::FailedPrecondition(
"Not a trivial merge: no Switch node feeding into Merge node");
"Not a trivial switch: no Merge node feeding into Switch node");
}
for (auto it = merge_cluster->merge_nodes.begin();
@ -1252,17 +1350,25 @@ Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) {
return Status::OK();
}
Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) {
VLOG(1) << "ConvertMergeToXlaIf for " << merge_cluster->representative;
gtl::optional<Cluster*> switch_cluster = GetSwitchCluster(*merge_cluster);
if (!switch_cluster.has_value()) {
Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf(
Cluster* switch_cluster) {
VLOG(1) << "ConvertMergeToXlaIf for " << switch_cluster->representative;
gtl::optional<Cluster*> maybe_merge =
CreateCorrespondingMergeCluster(switch_cluster);
if (!maybe_merge.has_value()) {
return errors::FailedPrecondition(
"Merge cluster was not part of a simple conditional in the clustered "
"graph. Graph nodes in merge cluster ",
NodesToString(merge_cluster->merge_nodes));
"Switch cluster was not part of a simple conditional in the clustered "
"graph. Graph nodes in switch cluster ",
NodesToString(switch_cluster->switch_nodes));
}
Cluster* merge_cluster = *maybe_merge;
if (merge_cluster->merge_nodes.empty()) {
return errors::Internal(
"Merge node in clustered graph contains no merge nodes: ",
merge_cluster->representative.ToString());
}
TF_ASSIGN_OR_RETURN(auto cond_args,
DetermineCondArgs(*merge_cluster, **switch_cluster));
DetermineCondArgs(*merge_cluster, *switch_cluster));
// Sort the outputs by ID to produce more stable output.
std::vector<Node*> outputs(merge_cluster->merge_nodes.begin(),
@ -1278,7 +1384,7 @@ Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) {
// Remove the old nodes from the graph_ and contract the edges of the
// clustered graph.
for (auto in : merge_cluster->in_nodes) {
if (in != *switch_cluster) {
if (in != switch_cluster) {
RemoveClusterNodes(in);
}
}
@ -1286,20 +1392,20 @@ Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) {
RemoveUnusedArgs(cond_args.args);
auto in_nodes = merge_cluster->in_nodes;
for (auto it = in_nodes.begin(); it != in_nodes.end();) {
ContractEdge(*it++, merge_cluster);
ContractEdge(*it++, switch_cluster);
}
ContractEdge(*switch_cluster, merge_cluster);
clusters_[if_node].Get() = ClusterHandle(merge_cluster->representative);
ContractEdge(merge_cluster, switch_cluster);
clusters_[if_node].Get() = ClusterHandle(switch_cluster->representative);
return Status::OK();
}
std::vector<std::pair<int, FunctionalizeCond::Cluster*>>
FunctionalizeCond::SortedMergeNodes() {
FunctionalizeCond::SortedSwitchNodes() {
VLOG(2) << "ProcessClusteredGraph";
std::stack<std::pair<int, Cluster*>> stack;
// Initialize with the source node.
stack.push({0, &clustered_graph_[ClusterHandle(Graph::kSourceId)]});
stack.push({0, &clustered_graph_[Representative(graph_->source_node())]});
// Perform a depth-first traversal of the clustered graph computing the
// switch-merge depth.
@ -1317,10 +1423,10 @@ FunctionalizeCond::SortedMergeNodes() {
size_t new_depth = depth;
if (!n->merge_nodes.empty()) {
queue.emplace_back(depth, n);
--new_depth;
}
if (!n->switch_nodes.empty()) {
queue.emplace_back(depth, n);
++new_depth;
}
for (Cluster* e : n->out_nodes) {
@ -1350,25 +1456,30 @@ Status FunctionalizeCond::Functionalize(Graph* graph,
}
fc.CreateClusteredGraph();
auto queue = fc.SortedMergeNodes();
auto queue = fc.SortedSwitchNodes();
for (auto it = queue.begin(); it != queue.end();) {
Cluster* merge_cluster = (*it).second;
Cluster* switch_cluster = (*it).second;
++it;
if (merge_cluster->in_nodes.size() == 1) {
TF_RETURN_IF_ERROR(fc.RemoveTrivialMerge(merge_cluster));
if (switch_cluster->out_nodes.size() == 1) {
TF_RETURN_IF_ERROR(fc.RemoveTrivialSwitch(switch_cluster));
} else {
TF_RETURN_IF_ERROR(fc.ConvertMergeToXlaIf(merge_cluster));
TF_RETURN_IF_ERROR(fc.ConvertCorrespondingMergeToXlaIf(switch_cluster));
}
// Contract newly Merge free merge_cluster with incoming nodes without
// Contract newly Switch free switch_cluster with outgoing nodes without
// Switch or Merge nodes.
std::vector<Cluster*> in_nodes(merge_cluster->in_nodes.begin(),
merge_cluster->in_nodes.end());
for (auto in : in_nodes) {
if (in->merge_nodes.empty() && in->switch_nodes.empty()) {
fc.ContractEdge(in, merge_cluster);
for (auto& nodes : {switch_cluster->out_nodes, switch_cluster->in_nodes}) {
std::vector<Cluster*> copy_nodes(nodes.begin(), nodes.end());
for (auto* node : copy_nodes) {
if (node->merge_nodes.empty() && node->switch_nodes.empty()) {
fc.ContractEdge(node, switch_cluster);
}
}
}
VLOG(3) << "Graph with clusters: "
<< DebugString(*fc.graph_, &fc.clusters_);
VLOG(3) << "ClusteredGraph: " << DebugString(fc.clustered_graph_);
}
if (!fc.switch_nodes_.empty()) {

View File

@ -153,6 +153,7 @@ bool ComputationBuilder::MakeWindow(
} else {
dim->set_window_dilation(1);
}
dim->set_window_reversal(false);
}
return true;
}
@ -1163,6 +1164,34 @@ ComputationDataHandle ComputationBuilder::ConvertElementType(
return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::BitcastConvertType(
const ComputationDataHandle& operand, PrimitiveType new_element_type) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
}
StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
if (!shape_status.ok()) {
first_error_ = shape_status.status();
return ComputationDataHandle();
}
std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();
ConvertRequest request;
*request.mutable_operand() = operand;
request.set_new_element_type(new_element_type);
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_bitcast_convert_request() = request;
AddCommonFieldsToOpRequest(&op_request);
OpResponse response;
VLOG(2) << "making bitcast convert request";
Status s = client_->stub()->Op(&op_request, &response);
return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::SquareF32(
const ComputationDataHandle& operand) {
return BinaryOp(BINOP_POW, operand, ConstantR0<float>(2.0),

View File

@ -121,14 +121,10 @@ class ComputationBuilder {
// result, OpMetadata is set on the Computation Builder. All subsequent
// instructions generated via this Computation Builder will have the same
// OpMetadata attached until a call to ClearOpMetdata.
void SetOpMetadata(const OpMetadata& metadata) {
metadata_ = metadata;
}
void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
// Clears the HloMetadata state.
void ClearOpMetadata() {
metadata_.Clear();
}
void ClearOpMetadata() { metadata_.Clear(); }
// Sets an OpSharding that will be attached to all instructions until cleared.
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
@ -673,6 +669,13 @@ class ComputationBuilder {
ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
PrimitiveType new_element_type);
// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to primitive_type. The
// bit-widths of the source and destination element types must be
// identical.
ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand,
PrimitiveType new_element_type);
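A hedged usage sketch of the new builder method; F32 and U32 are both 32 bits wide, so the bitcast is legal under the rule stated above. The builder and operand handle are assumed to exist.

// Sketch only, not part of this commit.
xla::ComputationDataHandle BitcastF32ToU32(
    xla::ComputationBuilder* builder,
    const xla::ComputationDataHandle& f32_operand) {
  // Reinterprets the bits of each F32 element as a U32 element.
  return builder->BitcastConvertType(f32_operand, xla::U32);
}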
// Enqueues a float32 reciprocal instruction onto the computation.
// (float32 is specified as there is an implicit float32 -1.0f constant
// exponent).

View File

@ -275,9 +275,6 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
device_ordinal, options));
}
// Copy the literal data to the device with the given ordinal and return as a
// ScopedShapedBuffer. The given memory allocator is used for device memory
// allocation.
StatusOr<std::unique_ptr<ScopedShapedBuffer>>
LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal,
DeviceMemoryAllocator* allocator) {
@ -298,8 +295,6 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal,
return std::move(scoped_buffer);
}
// Copy the data from the device contained in the given ShapedBuffer and
// return as a Literal.
StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(
@ -309,4 +304,22 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
shaped_buffer);
}
Status LocalClient::TransferToInfeedLocal(const Literal& literal,
int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
return backend().transfer_manager()->TransferLiteralToInfeed(executor,
literal);
}
StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
const Shape& shape, int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
auto literal = MakeUnique<Literal>();
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, shape, literal.get()));
return std::move(literal);
}
} // namespace xla

View File

@ -162,6 +162,20 @@ class LocalClient : public Client {
StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer);
// Transfer the given literal to the infeed queue of the given device.
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with
// Client::TransferToInfeed.
Status TransferToInfeedLocal(const Literal& literal, int device_ordinal);
// Transfer and return a value of the given shape from the outfeed of the
// given device.
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with
// Client::TransferFromOutfeed.
StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
const Shape& shape, int device_ordinal);
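A minimal sketch of pairing the two new entry points (the `client` pointer, device ordinal 0, and the literal contents are assumptions for illustration):
// Hypothetical usage, assuming `client` is a LocalClient* and device 0 exists.
std::unique_ptr<Literal> input = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
TF_CHECK_OK(client->TransferToInfeedLocal(*input, /*device_ordinal=*/0));
// ... execute a computation that reads the infeed and writes an outfeed ...
StatusOr<std::unique_ptr<Literal>> result = client->TransferFromOutfeedLocal(
    ShapeUtil::MakeShape(F32, {3}), /*device_ordinal=*/0);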
// Returns the platform that the underlying service targets.
perftools::gputools::Platform* platform() const;

View File

@ -1304,6 +1304,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
@ -1638,10 +1639,14 @@ cc_library(
deps = [
":buffer_liveness",
":hlo",
":hlo_alias_analysis",
":hlo_dce",
":hlo_graph_dumper",
":hlo_ordering",
":hlo_pass",
":liveness_util",
":logical_buffer",
":tuple_points_to_analysis",
":tuple_simplifier",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@ -1656,15 +1661,17 @@ tf_cc_test(
deps = [
":copy_insertion",
":hlo",
":hlo_graph_dumper",
":hlo_matchers",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)

View File

@ -1108,9 +1108,15 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
Literal::One(rhs->shape().element_type()).CloneToUnique()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
// broadcasting at HLO level.
auto* broadcast_one = computation_->AddInstruction(
HloInstruction::CreateBroadcast(power->shape(), one, {}));
return ReplaceWithNewInstruction(
power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
one, lhs));
broadcast_one, lhs));
}
return Status::OK();
}
@ -1398,6 +1404,15 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
auto operand = reduce_window->mutable_operand(0);
const Window& window = reduce_window->window();
auto function = reduce_window->to_apply();
if (ShapeUtil::IsScalar(operand->shape())) {
TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape()));
return ReplaceWithNewInstruction(
reduce_window,
HloInstruction::CreateMap(reduce_window->shape(),
{operand, reduce_window->mutable_operand(1)},
function));
}
VLOG(10) << "Considering folding Pad: " << operand->ToString()
<< "\ninto reduce-window: " << reduce_window->ToString();

View File

@ -761,8 +761,10 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Divide(op::Constant(), param0));
EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
EXPECT_THAT(root, op::Divide(op::Broadcast(), param0));
EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast);
EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement<float>(),
1);
}
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {

View File

@ -1265,7 +1265,6 @@ const LogicalBuffer* AddBufferToColocatedSet(
// CopyInsertion ensures root points-to set is unambiguous and distinct.
const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
DCHECK(!points_to.IsAmbiguous());
DCHECK(points_to.IsDistinct());
colocated_set->push_back(points_to.element(index)[0]);
return colocated_set->back();
}

View File

@ -1538,8 +1538,6 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto output1 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@ -1556,10 +1554,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
auto body1 =
module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({input0, weights0, output1}));
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));
module->AddEntryComputation(builder.Build());
RunCopyInsertion(module.get());
@ -1676,11 +1672,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
auto root_add = builder.AddInstruction(HloInstruction::CreateBinary(
while0->shape(), HloOpcode::kAdd, while0, while1));
module->AddEntryComputation(builder.Build());
while0->shape(), HloOpcode::kAdd, gte0, gte1));
RunCopyInsertion(module.get());
module->AddEntryComputation(builder.Build());
{
FlattenCallGraph flatten;
@ -1688,22 +1687,22 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
EXPECT_TRUE(result);
}
RunCopyInsertion(module.get());
auto sequence =
CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
// To trigger b/38494731, we want a specific Hlo sequence for the
// root computation, so we overwrite that entry with a manually
// crafted sequence.
std::vector<const HloInstruction*> sequence_for_buffer_assignment = {
input1, weights1, one, output1, tuple1, while1, input0,
weights0, zero, output0, tuple0, while0, root_add};
sequence[module->entry_computation()] = {
input1, weights1, one, output1, while1->operand(0), while1,
input0, weights0, zero, output0, while0->operand(0), while0,
gte0, gte1, root_add};
// If this ASSERT_TRUE fails, we constructed a bogus sequence above
// and this test itself is buggy.
ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assignment));
sequence[module->entry_computation()] =
std::move(sequence_for_buffer_assignment);
ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()]));
auto assignment =
BufferAssigner::Run(
@ -1716,55 +1715,6 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
// Test buffer assignment for while nodes with multiple uses.
// TODO(b/37245345): Fix buffer assignment for this case.
TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
auto module = xla::MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder(TestName());
auto input0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape_, "input0"));
auto weights0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
auto body0 =
module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({input0, weights0, output0}));
auto while0 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0));
auto get0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
auto get1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
builder.AddInstruction(
HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1));
module->AddEntryComputation(builder.Build());
RunCopyInsertion(module.get());
{
FlattenCallGraph flatten;
TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
EXPECT_TRUE(result);
}
auto assignment = RunBufferAssignment(module.get());
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
auto module = xla::MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder("entry");

View File

@ -97,21 +97,32 @@ class Compiler {
// Returns the ID of the platform that this compiler targets.
virtual perftools::gputools::Platform::Id PlatformId() const = 0;
// Runs Hlo passes to optimize the given Hlo module, returns the optimized
// module.
virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* executor) = 0;
// Compiles the HLO module for execution on a device given by the executor,
// and returns an executable object or an error status. Takes ownership of the
// HLO module and is free to transform it.
// and returns an executable object or an error status. No HLO passes are
// applied to the module. Generally a module should be passed through RunHloPasses
// prior to calling this method because some HLO passes are required for
// correctness. Takes ownership of the HLO module and is free to transform it.
//
// The compiler may optionally specialize to the individual device
// (not just type of device) indicated by the executor.
//
// Use the overload below to compile computations that run in parallel.
virtual StatusOr<std::unique_ptr<Executable>> Compile(
virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* executor) = 0;
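A hedged sketch of the resulting two-phase flow as a caller might chain it (the `compiler`, `module`, and `executor` variables are assumptions, not code from this change):
// Hypothetical caller-side sketch of the RunHloPasses/RunBackend split,
// inside a function that returns a Status or StatusOr.
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> optimized,
                    compiler->RunHloPasses(std::move(module), executor));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                    compiler->RunBackend(std::move(optimized), executor));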
// Compiles a set of HLO modules that can run in parallel, potentially
// communicating data between the modules, and returns a corresponding
// sequence of executable objects.
//
// TODO(b/68666782): Remove this method after adding support for multiple
// modules to RunHloPasses and RunBackends.
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::vector<std::unique_ptr<HloModule>> modules,
std::vector<std::vector<perftools::gputools::StreamExecutor*>>

File diff suppressed because it is too large

View File

@ -25,12 +25,25 @@ limitations under the License.
namespace xla {
// HLO pass which inserts a copy of the root instruction (creating a new root)
// if the root is or points-to any constant or parameter instruction.
// If the root instruction is a Tuple, only tuple elements which point to
// constant or parameter instructions will be copied.
// Copy insertion is necessary because constant and parameter arrays have
// different lifetimes than computation results.
// Copy insertion is a legalization HLO pass which inserts copies (kCopy
// instructions) to eliminate several kinds of problems in the HLO module.
//
// (1) Entry parameter or a constant live out of the entry computation. Entry
// computation arguments and constants have different lifetimes than the
// computation result and cannot share the same allocation. Parameters and
// constants live out of non-entry computations do not need copies.
//
// (2) Different values which are simultaneously live and which must be held
// in the same buffer. This can occur in while bodies. Specifically, the
// while loop state (the arguments to the while instruction) is updated
// in-place and the update may clobber the value from the previous
// iteration before the previous value is dead. Computations called from
// kCall instructions do not need such copies because kCall has no update
// in-place semantics.
//
// (3) The buffer set of the root instruction of the entry computation must be
// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous must
// return false and InstructionAliasSet::IsDistinct must return true.
class CopyInsertion : public HloPassInterface {
public:
tensorflow::StringPiece name() const override { return "copy-insertion"; }
@ -39,14 +52,16 @@ class CopyInsertion : public HloPassInterface {
// (copies were inserted).
StatusOr<bool> Run(HloModule* module) override;
protected:
// Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
// duplicate copies.
StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
// A map containing all copies inserted during the copy insertion pass. The
// key is the copied instruction and the value is the copy.
tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
// The CPU and GPU backends need additional copies added due to deficiencies in
// buffer assignment. Specifically, copies are needed for constants live-out
// of computations, and for values which are live-in and live-out of the same
// computation. These copies are needed because buffer assignment uses a
// computation-scoped analysis (TuplePointsToAnalysis) and has limited
// visibility across computation boundaries. This method adds these necessary
// copies. Returns whether the module was modified.
//
// TODO(b/62548313): Remove this when buffer assignment is module-scoped.
static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
};
} // namespace xla
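For orientation, a minimal sketch of driving this pass outside a pipeline (the `module` variable, an std::unique_ptr<HloModule>, is an assumption; in this change the backends wire it up via CpuCopyInsertion and GpuCopyInsertion):
// Hypothetical standalone invocation, inside a function returning a Status.
CopyInsertion copy_insertion;
TF_ASSIGN_OR_RETURN(bool changed, copy_insertion.Run(module.get()));
// Backends whose buffer assignment is computation-scoped additionally run:
TF_ASSIGN_OR_RETURN(
    bool more_changes,
    CopyInsertion::AddCopiesForBufferAssignment(module.get()));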

File diff suppressed because it is too large

View File

@ -79,6 +79,7 @@ cc_library(
deps = [
":compiler_functor",
":conv_canonicalization",
":cpu_copy_insertion",
":cpu_executable",
":cpu_instruction_fusion",
":cpu_options",
@ -103,7 +104,6 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
@ -273,6 +273,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:ops",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"@llvm//:code_gen",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
@ -750,6 +751,38 @@ cc_library(
],
)
cc_library(
name = "cpu_copy_insertion",
srcs = ["cpu_copy_insertion.cc"],
hdrs = ["cpu_copy_insertion.h"],
deps = [
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "cpu_copy_insertion_test",
srcs = ["cpu_copy_insertion_test.cc"],
deps = [
":cpu_copy_insertion",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -53,7 +53,7 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
// kernel and output.
//
// For simplicity, as a first step, we reshape the input and filter to
// NHWC and HWIO order, respectively. This may lose precision but not
// NHWC and HWIO order, respectively. This may lose precision but won't
// break the soundness.
HloInstruction* input = hlo->mutable_operand(0);
@ -98,14 +98,18 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
HloInstruction::CreateTranspose(new_kernel_shape, kernel,
new_kernel_dim_order));
std::vector<int64> new_output_dim_order(num_dims);
std::vector<int64> new_conv_dims(num_dims);
auto output_batch_dim = dnums.output_batch_dimension();
auto output_feature_dim = dnums.output_feature_dimension();
new_output_dim_order[0] = output_batch_dim;
new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim);
for (int i = 0; i < num_spatial_dims; ++i) {
new_output_dim_order[i + 1] = dnums.spatial_dimensions(i);
new_conv_dims[i + 1] =
hlo->shape().dimensions(dnums.spatial_dimensions(i));
}
new_output_dim_order[num_dims - 1] = output_feature_dim;
new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim);
Shape new_conv_shape =
ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims);
@ -129,14 +133,11 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
hlo->window(), new_dnums));
// kConvolution inherits the dimension mapping of its input, so we need to
// reshape the output back to the shape of the original convolution. This
// is done by apply the inverse permutation of the collapsing order of the
// input reshape.
// Reshape the output back to the shape of the original convolution.
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
hlo, HloInstruction::CreateTranspose(
hlo->shape(), new_conv,
InversePermutation(new_input_dim_order))));
InversePermutation(new_output_dim_order))));
changed = true;
}
}

View File

@ -46,9 +46,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
@ -332,15 +332,16 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
// (and sometime after) copy insertion, to avoid dead code from interfering
// with the rewrites.
pipeline.AddPass<HloDCE>();
pipeline.AddPass<CopyInsertion>();
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuCopyInsertion>();
if (options::CpuParallelBackendRequested(module->config())) {
// Re-run the outlining, in case any copies were inserted into the entry
// computation.
pipeline.AddPass<ParallelizationPreparation>(max_parallelism,
ShapeSizeBytesFunction());
pipeline.AddPass<CpuCopyInsertion>();
}
pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>();
return pipeline.Run(module).status();
}
@ -426,11 +427,25 @@ Status InitializeModuleHooks(
} // namespace
StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* /*stream_exec*/) {
VLOG(2) << "Before optimization:";
XLA_VLOG_LINES(2, module->ToString());
TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false));
VLOG(2) << "After optimization:";
XLA_VLOG_LINES(2, module->ToString());
return std::move(module);
}
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* stream_exec) {
const string timer_message =
"Compiling [" + module->name() + "] for CPU using JIT";
ScopedLoggingTimer compiling_timer(timer_message, 1);
XLA_SCOPED_LOGGING_TIMER(timer_message);
VLOG(1) << "Compiling: " << module->name();
TF_RET_CHECK(stream_exec != nullptr);
@ -458,14 +473,6 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
llvm_module->setDataLayout(jit->data_layout());
llvm_module->setTargetTriple(jit->target_triple().getTriple());
VLOG(2) << "Before optimization:";
XLA_VLOG_LINES(2, module->ToString());
TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false));
VLOG(2) << "After optimization:";
XLA_VLOG_LINES(2, module->ToString());
HloComputation* computation = module->entry_computation();
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
if (module->config().hlo_profiling_enabled()) {
@ -537,11 +544,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
parallel_computations.emplace(to_apply, instruction);
}
size_t entry_computation_profile_idx = hlo_to_profile_idx.size();
IrEmitter ir_emitter(
*module, *assignment, llvm_module.get(), std::move(hlo_to_profile_idx),
/*entry_computation_profile_idx=*/entry_computation_profile_idx,
jit->target_machine(), jit->external_constant_pool());
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
hlo_to_profile_idx, hlo_to_profile_idx.size(),
jit->target_machine(), jit->external_constant_pool());
std::unique_ptr<HloInstructionMap<string>> function_names(
new HloInstructionMap<string>());
@ -619,11 +624,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
// before the entry computation. The order of computations returned from
// GetEmbeddedComputations guarantees that a called computation occurs
// before a caller computation.
size_t entry_computation_profile_idx = hlo_to_profile_idx.size();
IrEmitter ir_emitter(
*module, *assignment, llvm_module.get(), std::move(hlo_to_profile_idx),
/*entry_computation_profile_idx=*/entry_computation_profile_idx,
jit->target_machine(), jit->external_constant_pool());
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
hlo_to_profile_idx, hlo_to_profile_idx.size(),
jit->target_machine(), jit->external_constant_pool());
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {

View File

@ -116,7 +116,11 @@ class CpuCompiler : public LLVMCompiler {
// stream_execs)
using LLVMCompiler::Compile;
StatusOr<std::unique_ptr<Executable>> Compile(
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* stream_exec) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* stream_exec) override;

View File

@ -0,0 +1,43 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include <memory>
#include <set>
#include <vector>
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
StatusOr<bool> CpuCopyInsertion::Run(HloModule* module) {
CopyInsertion generic_copy_insertion;
TF_ASSIGN_OR_RETURN(bool generic_changed, generic_copy_insertion.Run(module));
// The CPU backend needs additional copies added due to deficiencies in
// buffer assignment.
TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed,
CopyInsertion::AddCopiesForBufferAssignment(module));
return generic_changed || buffer_assignment_changed;
}
} // namespace xla

View File

@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// Besides the modifications made by the generic xla::CopyInsertion, this
// CPU-specific copy insertion pass also adds copies to values live out of
// computations that satisfy certain conditions (e.g. values defined by a
// constant or parameter). This is necessary because of deficiencies in buffer
// assignment. Specifically, buffer assignment is computation-scoped and does
// not recognize aliasing between arguments and outputs of computations.
//
// TODO(b/62548313): Remove this when buffer assignment is smarter
// (module-scoped).
class CpuCopyInsertion : public HloPassInterface {
public:
tensorflow::StringPiece name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
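A loose sketch of how such a backend-specific pass is registered (the pipeline name and the `module` pointer are assumptions; cpu_compiler.cc in this change does the equivalent wiring after DCE and FlattenCallGraph):
// Hypothetical pipeline wiring; `module` is an assumed HloModule*.
HloPassPipeline pipeline("prepare-for-ir-emission");
pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuCopyInsertion>();
TF_RETURN_IF_ERROR(pipeline.Run(module).status());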

View File

@ -0,0 +1,139 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace xla {
namespace {
namespace op = xla::testing::opcode_matchers;
int64 CountCopies(const HloComputation& computation) {
int64 count = 0;
for (const auto& instruction : computation.instructions()) {
if (instruction->opcode() == HloOpcode::kCopy) {
count++;
}
}
return count;
}
int64 CountCopies(const HloModule& module) {
int64 count = 0;
for (const auto& computation : module.computations()) {
count += CountCopies(*computation);
}
return count;
}
class CpuCopyInsertionTest : public HloTestBase {
protected:
void InsertCopies(HloModule* module) {
CpuCopyInsertion copy_insertion;
ASSERT_IS_OK(copy_insertion.Run(module).status());
}
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
};
TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
// Test a while body and condition which are each simply a constant (root of
// computation is a constant). Each constant should be copied.
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto param_0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));
auto body_builder = HloComputation::Builder("body");
body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
body_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0));
module->AddEntryComputation(builder.Build());
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 3);
EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter()));
EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
EXPECT_THAT(condition->root_instruction(), op::Copy(op::Constant()));
}
TEST_F(CpuCopyInsertionTest, TupleCall) {
// Test a kCall instruction which calls a computation which produces a three
// element tuple: one is a constant, one is a parameter, and one is produced
// in the computation. The constant and parameter should be copied.
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_});
auto sub_builder = HloComputation::Builder("subcomputation");
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
auto constant = sub_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, sub_param, constant));
sub_builder.AddInstruction(
HloInstruction::CreateTuple({sub_param, constant, add}));
HloComputation* subcomputation =
module->AddEmbeddedComputation(sub_builder.Build());
builder.AddInstruction(
HloInstruction::CreateCall(tuple_shape, {param}, subcomputation));
module->AddEntryComputation(builder.Build());
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*subcomputation), 2);
EXPECT_THAT(subcomputation->root_instruction(),
op::Tuple(op::Copy(op::Parameter()), op::Copy(op::Constant()),
op::Add()));
}
} // namespace
} // namespace xla

View File

@ -147,8 +147,9 @@ Status CpuExecutable::ExecuteComputeFunction(
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
std::vector<se::DeviceMemoryBase> argument_buffers;
for (int i = 0; i < arguments.size(); ++i) {
argument_buffers.push_back(arguments[i]->buffer(/*index=*/{}));
argument_buffers.reserve(arguments.size());
for (const auto* argument : arguments) {
argument_buffers.push_back(argument->buffer(/*index=*/{}));
}
return ExecuteComputeFunction(run_options, argument_buffers, buffers,
hlo_execution_profile);

View File

@ -26,14 +26,14 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Target/TargetRegisterInfo.h"
#include "llvm/Target/TargetSubtargetInfo.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
@ -795,7 +795,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// operand index is within the bounds. The unsigned comparison includes
// checking whether the operand index >= 0.
llvm_ir::IrArray::Index operand_index(source_index.size());
llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
llvm::Value* in_bounds_condition = ir_builder_.getTrue();
for (int64 i = 0; i < rank; ++i) {
llvm::Value* strided_index = ir_builder_.CreateNSWMul(
source_index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
@ -1140,7 +1140,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0));
};
llvm::Value* in_bounds_condition = nullptr;
llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
for (int i = 0; i < num_spatial_dims; ++i) {
llvm::ConstantInt* input_bound =
ir_builder_.getInt64(window_util::DilatedBound(
@ -1153,9 +1153,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
llvm::Value* dim_ok =
ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole);
in_bounds_condition =
in_bounds_condition
? ir_builder_.CreateAnd(in_bounds_condition, dim_ok)
: dim_ok;
ir_builder_.CreateAnd(in_bounds_condition, dim_ok);
}
// Now we need to map the dilated base coordinates back to the actual

View File

@ -86,6 +86,9 @@ class DfsHloVisitorBase {
virtual Status HandleConvert(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
virtual Status HandleBitcastConvert(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
virtual Status HandleCopy(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
@ -208,6 +211,7 @@ class DfsHloVisitorBase {
virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0;
virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
virtual Status HandlePad(HloInstructionPtr hlo) = 0;

View File

@ -167,6 +167,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleWhile(HloInstructionPtr xla_while) override {
return DefaultAction(xla_while);
}
Status HandleConditional(HloInstructionPtr conditional) override {
return DefaultAction(conditional);
}
Status HandleRecv(HloInstructionPtr recv) override {
return DefaultAction(recv);
}

View File

@ -110,6 +110,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
PrimitiveType_Name(from_type).c_str(),
PrimitiveType_Name(to_type).c_str());
}
case HloOpcode::kBitcastConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
PrimitiveType to_type = op->shape().element_type();
CHECK(primitive_util::IsIntegralType(from_type));
if (from_type == to_type) {
return operand_value;
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
return ir_builder_->CreateBitCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
"bitcast conversion from primitive type %s to %s with unequal "
"bit-widths (%u versus %u) ",
PrimitiveType_Name(from_type).c_str(),
PrimitiveType_Name(to_type).c_str(),
primitive_util::BitWidth(from_type),
primitive_util::BitWidth(to_type));
}
case HloOpcode::kAbs: {
bool is_signed =
primitive_util::IsSignedIntegralType(op->shape().element_type());
@ -203,6 +223,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
PrimitiveType_Name(from_type).c_str(),
PrimitiveType_Name(to_type).c_str());
}
case HloOpcode::kBitcastConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
PrimitiveType to_type = op->shape().element_type();
CHECK(primitive_util::IsFloatingPointType(from_type));
if (from_type == to_type) {
return operand_value;
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
return ir_builder_->CreateBitCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
"bitcast conversion from primitive type %s to %s with unequal "
"bit-widths (%u versus %u) ",
PrimitiveType_Name(from_type).c_str(),
PrimitiveType_Name(to_type).c_str(),
primitive_util::BitWidth(from_type),
primitive_util::BitWidth(to_type));
}
case HloOpcode::kExp:
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value},
{operand_value->getType()},
@ -1073,6 +1113,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kConvert:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
@ -1081,11 +1122,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kTanh:
case HloOpcode::kNot:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
@ -1094,6 +1135,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
return EmitUnaryOp(hlo, operand_value);
};
case HloOpcode::kAdd:
case HloOpcode::kAnd:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
@ -1106,14 +1148,13 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kOr:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSubtract:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
const HloInstruction* lhs = hlo->operand(0);

View File

@ -343,15 +343,16 @@ tf_cc_test(
)
cc_library(
name = "copy_insertion",
srcs = ["copy_insertion.cc"],
hdrs = ["copy_insertion.h"],
name = "gpu_copy_insertion",
srcs = ["gpu_copy_insertion.cc"],
hdrs = ["gpu_copy_insertion.h"],
deps = [
":ir_emission_utils",
"//tensorflow/compiler/xla/service:call_graph",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/compiler/xla/service:tuple_points_to_analysis",
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
],
)
@ -427,8 +428,8 @@ cc_library(
hdrs = ["gpu_compiler.h"],
deps = [
":convolution_folding",
":copy_insertion",
":fusion_merger",
":gpu_copy_insertion",
":gpu_executable",
":hlo_schedule",
":instruction_fusion",
@ -574,11 +575,14 @@ tf_cc_test(
deps = [
":instruction_fusion",
":while_transformer",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)

View File

@ -258,22 +258,19 @@ tensorflow::Status ConvolutionThunk::Convolve(
}
std::vector<AlgorithmDesc> ConvolutionThunk::GetAlgorithms(
se::StreamExecutor* stream_exec) const {
bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const {
std::vector<AlgorithmDesc> algorithms;
// TODO(yangzihao): Currently disable the use of winograd nonfused in XLA
// by default. Should send in conv parameters and enable it when
// ShouldIncludeWinogradNonfusedAlgo() returns true.
switch (convolution_kind_) {
case ConvolutionKind::kBackwardFilter:
CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
/*with_winograd_nonfused=*/false, &algorithms));
with_winograd_nonfused, &algorithms));
break;
case ConvolutionKind::kBackwardInput:
CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
/*with_winograd_nonfused=*/false, &algorithms));
with_winograd_nonfused, &algorithms));
break;
case ConvolutionKind::kForward:
CHECK(stream_exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/false,
CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused,
&algorithms));
break;
}
@ -287,6 +284,26 @@ static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) {
return tensorflow::strings::StrCat(algo.algo_id());
}
// Determines whether we can safely perform a winograd non-fused convolution for
// the given input and output descriptors. This works around b/68264959, an
// integer overflow in cuDNNv5 and cuDNNv6.
static bool ShouldIncludeWinogradNonfusedAlgo(
const BatchDescriptor& input_descriptor,
const BatchDescriptor& output_descriptor) {
int64 batch = input_descriptor.count();
int64 in_depths = input_descriptor.feature_map_count();
int64 in_rows = input_descriptor.height();
int64 in_cols = input_descriptor.width();
int64 out_depths = output_descriptor.feature_map_count();
int64 total_size = 16 * std::ceil(batch / 16.0) *
std::max(in_depths, out_depths) * in_cols * in_rows *
sizeof(float);
int64 threshold = 1L << 31;
return total_size < threshold;
}
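// Worked example with illustrative (assumed) dimensions, not taken from this
// change: batch = 32, in_depths = 256, in_rows = in_cols = 56, out_depths =
// 512 gives
//   16 * ceil(32 / 16.0) * max(256, 512) * 56 * 56 * sizeof(float)
//   = 16 * 2 * 512 * 3136 * 4 = 205,520,896 bytes,
// which is below the 2^31-byte threshold, so the winograd non-fused
// algorithms would be included for such a convolution.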
tensorflow::Status ConvolutionThunk::ConvolveWithTune(
const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data,
const FilterDescriptor& filter_descriptor,
@ -303,9 +320,13 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
"ConvolutionThunk: "
<< this;
bool with_winograd_nonfused =
ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor);
se::dnn::ProfileResult best_result;
se::dnn::ProfileResult best_result_without_scratch;
std::vector<AlgorithmDesc> algorithms = GetAlgorithms(stream->parent());
std::vector<AlgorithmDesc> algorithms =
GetAlgorithms(with_winograd_nonfused, stream->parent());
for (auto algorithm : algorithms) {
ConvolveScratchAllocator scratch_allocator(
buffer_allocations.device_ordinal(),

View File

@ -116,6 +116,7 @@ class ConvolutionThunk : public Thunk {
// Returns the convolve algorithms that can be used for this ConvolutionThunk.
std::vector<perftools::gputools::dnn::AlgorithmDesc> GetAlgorithms(
bool with_winograd_nonfused,
perftools::gputools::StreamExecutor* stream_exec) const;
// Fastest cuDNN convolution algorithm for this thunk learned from

View File

@ -1,71 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h"
#include <memory>
#include <set>
#include <vector>
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace gpu {
StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(bool changed, CopyInsertion::Run(module));
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
TuplePointsToAnalysis::Run(module));
// Make sure all operands of a library call are in memory instead of constants
// in IR. The top-level (index {}) of the points-to set of each operand
// indicates the source(s) of the array buffer. If any of these are constant,
// then add a copy to materialize the array.
HloComputation* computation = module->entry_computation();
for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
if (ImplementedAsLibraryCall(*hlo)) {
for (int64 i = 0; i < hlo->operand_count(); ++i) {
HloInstruction* operand = hlo->mutable_operand(i);
const PointsToSet& points_to =
points_to_analysis->GetPointsToSet(operand);
const auto& element = points_to.element(/*index=*/{});
if (std::any_of(element.begin(), element.end(),
[](const LogicalBuffer* buffer_source) {
return buffer_source->instruction()->opcode() ==
HloOpcode::kConstant;
})) {
TF_ASSIGN_OR_RETURN(HloInstruction * copy,
CopyInsertion::FindOrInsertCopy(operand));
TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy));
changed = true;
}
}
}
}
return changed;
}
} // namespace gpu
} // namespace xla

View File

@ -33,8 +33,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h"
#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
@ -126,7 +126,7 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) {
// Runs optimization passes on the given HLO module.
tensorflow::Status OptimizeHloModule(
HloModule* hlo_module, const se::DeviceDescription& device_desc,
HloModule* hlo_module,
const HloCostAnalysis::ShapeSizeFunction& shape_size_function) {
{
HloPassPipeline pipeline("optimization");
@ -224,9 +224,8 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
// (and sometime after) copy insertion, to avoid dead code from interfering
// with the rewrites.
pipeline.AddPass<HloDCE>();
pipeline.AddPass<GpuCopyInsertion>();
pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<GpuCopyInsertion>();
return pipeline.Run(hlo_module).status();
}
@ -297,19 +296,23 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
GpuCompiler::GpuCompiler()
: pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {}
StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/) {
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
Tracing::TraceMe annotation("HLO Transforms", module->name(),
/*is_expensive=*/true);
TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), ShapeSizeBytesFunction()));
return std::move(module);
}
StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
TF_RET_CHECK(stream_exec != nullptr);
{
Tracing::TraceMe annotation("HLO Transforms", module->name(),
/*is_expensive=*/true);
TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(),
stream_exec->GetDeviceDescription(),
ShapeSizeBytesFunction()));
TF_RETURN_IF_ERROR(
PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
}
TF_RETURN_IF_ERROR(
PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
llvm::LLVMContext llvm_context;
std::string buffer;
@ -362,8 +365,11 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
HloComputation* entry_computation = module->entry_computation();
IrEmitterUnnested ir_emitter(module->config(), entry_computation,
&ir_emitter_context);
TF_RETURN_IF_ERROR(
entry_computation->root_instruction()->Accept(&ir_emitter));
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
TF_RETURN_IF_ERROR(
entry_computation->root_instruction()->Accept(&ir_emitter));
}
if (user_pre_optimization_hook_) {
TF_CHECK_OK(user_pre_optimization_hook_(llvm_module));
@ -412,9 +418,12 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
cc_minor = 0;
}
TF_ASSIGN_OR_RETURN(string ptx,
CompileToPtx(&llvm_module, {cc_major, cc_minor},
module->config(), libdevice_dir));
string ptx;
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - CompileToPtx");
TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor},
module->config(), libdevice_dir));
}
if (!ir_dump_directory.empty()) {
TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory(
@ -470,6 +479,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
int cc_major,
int cc_minor) {
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult");
Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true);
bool inserted;
decltype(compilation_cache_.begin()) iter;

View File

@ -49,7 +49,11 @@ class GpuCompiler : public LLVMCompiler {
// stream_execs)
using LLVMCompiler::Compile;
StatusOr<std::unique_ptr<Executable>> Compile(
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* stream_exec) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* stream_exec) override;

View File

@ -0,0 +1,112 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
#include <memory>
#include <set>
#include <vector>
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace gpu {
StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
HloInstruction* hlo) {
HloInstruction*& copy = inserted_copies_[hlo];
if (copy == nullptr) {
TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
}
return copy;
}
StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
CopyInsertion generic_copy_insertion;
TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
HloDataflowAnalysis::Run(module));
// Make sure all operands of a library call are in memory instead of constants
// in IR.
for (HloInstruction* hlo :
module->entry_computation()->MakeInstructionPostOrder()) {
if (ImplementedAsLibraryCall(*hlo)) {
for (int64 i = 0; i < hlo->operand_count(); ++i) {
HloInstruction* operand = hlo->mutable_operand(i);
TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
const auto& values = dataflow->GetValueSet(operand).values();
if (std::any_of(values.begin(), values.end(),
[](const HloValue* value) {
return value->defining_instruction()->opcode() ==
HloOpcode::kConstant;
})) {
TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand));
TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy));
changed = true;
}
}
}
}
// Init values of a while node cannot be constants. Insert copies for any
// constants found at the operand of a while.
tensorflow::gtl::FlatSet<HloInstruction*> copied_constants;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() != HloOpcode::kWhile) {
continue;
}
for (auto& pair :
dataflow->GetInstructionValueSet(instruction->operand(0))) {
const HloValueSet& value_set = pair.second;
for (const HloValue* value : value_set.values()) {
if (value->defining_instruction()->opcode() ==
HloOpcode::kConstant &&
!ContainsKey(copied_constants, value->defining_instruction())) {
HloInstruction* constant = value->defining_instruction();
TF_ASSIGN_OR_RETURN(HloInstruction * copy,
FindOrInsertCopy(constant));
TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
copied_constants.insert(constant);
changed = true;
}
}
}
}
}
// The GPU backend needs additional copies added due to deficiencies in
// buffer assignment.
TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed,
CopyInsertion::AddCopiesForBufferAssignment(module));
return changed || buffer_assignment_changed;
}
} // namespace gpu
} // namespace xla

View File

@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
namespace gpu {
@ -25,12 +25,23 @@ namespace gpu {
// Besides the modifications made by the generic xla::CopyInsertion, this
// GPU-specific copy insertion also materializes operands of library calls by
// inserting kCopy instructions.
class GpuCopyInsertion : public CopyInsertion {
class GpuCopyInsertion : public HloPassInterface {
public:
tensorflow::StringPiece name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
protected:
// Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
// duplicate copies.
StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
// A map containing all copies inserted to materialize operands of library
// calls. The key is the copied instruction and the value is the copy.
tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_

View File

@ -492,9 +492,8 @@ StatusOr<string> CompileToPtx(llvm::Module* module,
tensorflow::port::Tracing::TraceMe annotation(
"Compiling IR", llvm_ir::AsString(module->getName()),
/*is_expensive=*/true);
ScopedLoggingTimer compilation_timer(
"Compile module " + llvm_ir::AsString(module->getName()),
/*vlog_level=*/2);
XLA_SCOPED_LOGGING_TIMER("Compile module " +
llvm_ir::AsString(module->getName()));
TF_ASSIGN_OR_RETURN(
ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config,
libdevice_dir_path));

View File

@ -17,9 +17,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@ -33,8 +36,6 @@ class WhileTransformerTest : public HloTestBase {
: module_(CreateNewModule()),
induction_variable_shape_(ShapeUtil::MakeShape(S32, {})),
data_shape_(ShapeUtil::MakeShape(F32, {8})),
loop_state_shape_(ShapeUtil::MakeTupleShape(
{induction_variable_shape_, data_shape_})),
condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {}
std::unique_ptr<HloComputation> BuildConditionComputation(
@ -42,8 +43,8 @@ class WhileTransformerTest : public HloTestBase {
auto builder = HloComputation::Builder(TestName() + ".Condition");
auto limit_const = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<int32>(limit)));
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
0, GetLoopStateShape(tuple_index), "loop_state"));
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
limit_const->shape(), loop_state, tuple_index));
@ -58,8 +59,8 @@ class WhileTransformerTest : public HloTestBase {
const int64 increment) {
auto builder = HloComputation::Builder(TestName() + ".Body");
// Create param instruction to access loop state.
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
0, GetLoopStateShape(ind_var_tuple_index), "loop_state"));
// Update the induction variable GTE(ind_var_tuple_index).
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
@ -73,7 +74,7 @@ class WhileTransformerTest : public HloTestBase {
data_shape_, loop_state, data_tuple_index));
// Use 'induction_variable' in computation with no path to output tuple.
auto update = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {}));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data, update));
// Create output Tuple.
@ -98,8 +99,9 @@ class WhileTransformerTest : public HloTestBase {
HloInstruction::CreateTuple({induction_var_init, data_init}))
: builder.AddInstruction(
HloInstruction::CreateTuple({data_init, induction_var_init}));
auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition, body, loop_state_init));
auto while_hlo = builder.AddInstruction(
HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index),
condition, body, loop_state_init));
module_->AddEntryComputation(builder.Build());
return while_hlo;
}
@ -115,18 +117,34 @@ class WhileTransformerTest : public HloTestBase {
}
void RunCopyInsertionPass() {
HloVerifier verifier([](const Shape& shape) {
return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
});
TF_ASSERT_OK(verifier.Run(module_.get()).status());
CopyInsertion copy_insertion;
EXPECT_IS_OK(copy_insertion.Run(module_.get()).status());
TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
}
Shape GetLoopStateShape(const int64 ind_var_tuple_index) {
if (ind_var_tuple_index == 0) {
return ShapeUtil::MakeTupleShape(
{induction_variable_shape_, data_shape_});
} else {
return ShapeUtil::MakeTupleShape(
{data_shape_, induction_variable_shape_});
}
}
std::unique_ptr<HloModule> module_;
Shape induction_variable_shape_;
Shape data_shape_;
Shape loop_state_shape_;
Shape condition_result_shape_;
};
TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
// TODO(b/68830972): The while transformer is far too fragile. It pattern-
// matches the exact expressions of opcodes. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) {
// Build computation with induction variable at tuple element 0.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
@ -137,13 +155,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
RunCopyInsertionPass();
// Run WhileTransformer.
auto result = gpu::CanTransformWhileToFor(while_hlo);
ASSERT_TRUE(result.ok());
TF_ASSERT_OK(result.status());
// Check results.
EXPECT_THAT(result.ConsumeValueOrDie(),
Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}
TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
// TODO(b/68830972): The while transformer is far too fragile. It pattern-
// matches the exact expressions of opcodes. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) {
// Build computation with induction variable at tuple element 1.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(1, 10));
@ -154,13 +175,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
RunCopyInsertionPass();
// Run WhileTransformer.
auto result = gpu::CanTransformWhileToFor(while_hlo);
ASSERT_TRUE(result.ok());
TF_ASSERT_OK(result.status());
// Check results.
EXPECT_THAT(result.ConsumeValueOrDie(),
Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}
TEST_F(WhileTransformerTest, InvalidLoopLimit) {
// TODO(b/68830972): The while transformer is far too fragile. It pattern-
// matches the exact expressions of opcodes. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) {
// Build computation with invalid loop limit.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(0, 5));
@ -176,7 +200,10 @@ TEST_F(WhileTransformerTest, InvalidLoopLimit) {
HasSubstr("Loop start must be less than loop limit."));
}
TEST_F(WhileTransformerTest, InvalidLoopIncrement) {
// TODO(b/68830972): The while transformer is far too fragile. It pattern-
// matches the exact expressions of opcodes. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) {
// Build computation with invalid loop increment.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));

View File

@ -250,7 +250,3 @@ message HloProto {
HloOrderingProto hlo_ordering = 2;
BufferAssignmentProto buffer_assignment = 3;
}
message HloProtos {
repeated HloProto hlo_protos = 1;
}

View File

@ -144,8 +144,10 @@ class BufferValueMap {
// Move the given value into the given buffer.
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
buffers_.at(old_buffer_number).erase(&value);
if (buffers_.at(old_buffer_number).empty()) {
tensorflow::gtl::FlatSet<const HloValue*>& old_value_set =
buffers_.at(old_buffer_number);
old_value_set.erase(&value);
if (old_value_set.empty()) {
buffers_.erase(old_buffer_number);
}
@ -175,7 +177,7 @@ class BufferValueMap {
// Value is init of a while (use is while).
std::vector<BufferNumber> aliased_buffers;
for (const HloUse& use : value.uses()) {
VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
VLOG(2) << "use of value " << value.ToShortString() << ": " << use;
if (use.instruction->opcode() == HloOpcode::kWhile) {
// Determine the while value that this shares a buffer with.
const HloValue& while_value =
@ -411,7 +413,7 @@ string HloAliasAnalysis::ToString() const {
/* static */
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
HloModule* module) {
VLOG(1) << "HloAliasAnalysis::Run on module " << module->name();
VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
XLA_VLOG_LINES(2, module->ToString());
auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
@ -444,7 +446,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
TF_DCHECK_OK(alias_analysis->Verify());
XLA_VLOG_LINES(1, alias_analysis->ToString());
XLA_VLOG_LINES(2, alias_analysis->ToString());
return std::move(alias_analysis);
}
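The MoveValueToBuffer change above replaces two buffers_.at() probes with a single cached reference. A stand-alone illustration of the same pattern in plain C++ (container and names are hypothetical):

#include <set>
#include <unordered_map>

void MoveOut(std::unordered_map<int, std::set<int>>& buckets, int key, int value) {
  std::set<int>& bucket = buckets.at(key);  // single lookup
  bucket.erase(value);
  if (bucket.empty()) {
    buckets.erase(key);  // drop the now-empty bucket
  }
}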

View File

@ -407,16 +407,18 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
HloModule* module, const HloComputationProto& proto,
tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map,
const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
const std::function<void(std::unique_ptr<HloComputation>)>&
add_fused_computation,
HloInstruction* fusion_instruction) {
std::vector<std::unique_ptr<HloInstruction>> instructions;
tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloInstruction> instruction,
HloInstruction::CreateFromProto(module, instruction_proto,
instruction_map, computation_map));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
HloInstruction::CreateFromProto(
module, instruction_proto, instruction_map,
computation_map, add_fused_computation));
if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++;
}

View File

@ -152,12 +152,16 @@ class HloComputation {
// computation_map: a map from computation name to HloComputation*. This map
// must contain all computations which the newly constructed computation
// calls.
// fusion_instruction: if non-null then the newly created computation will be
// constructed as a fused computation with this instruction as its fusion
// parent.
// add_fused_computation: A function to call to add a fused
// computation. Used only when the instruction is a fusion instruction.
// fusion_instruction: if non-null then the newly created computation will
// be constructed as a fused computation with this instruction as its
// fusion parent.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
HloModule* module, const HloComputationProto& proto,
tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map,
const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
const std::function<void(std::unique_ptr<HloComputation>)>&
add_fused_computation,
HloInstruction* fusion_instruction = nullptr);
// Gets the instructions in this computation.

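A hedged sketch of how a caller might satisfy the new CreateFromProto signature; the lambda and the variables module, proto, and computation_map are illustrative, not the committed call site. The key difference is that the computation map is now read-only and fused computations are handed back through the callback:

std::vector<std::unique_ptr<HloComputation>> fused_computations;
auto add_fused_computation =
    [&fused_computations](std::unique_ptr<HloComputation> fused) {
      fused_computations.push_back(std::move(fused));
    };
TF_ASSIGN_OR_RETURN(
    std::unique_ptr<HloComputation> computation,
    HloComputation::CreateFromProto(module, proto, computation_map,
                                    add_fused_computation));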
View File

@ -22,13 +22,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace xla {
constexpr char HloCostAnalysis::kFlopsKey[];
constexpr char HloCostAnalysis::kTranscendentalsKey[];
constexpr char HloCostAnalysis::kBytesAccessedKey[];
constexpr char HloCostAnalysis::kSecondsKey[];
constexpr char HloCostAnalysis::kOptimalSecondsKey[];
HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
: HloCostAnalysis(shape_size, {}) {}
@ -60,16 +61,16 @@ Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
if (current_should_compute_bottleneck_time_) {
// Compute the time as the time of the bottleneck, i.e. the slowest property
// given the per-second rate of each property.
float max_seconds = 0.0f;
float optimal_seconds = 0.0f;
for (const auto& property : current_properties_) {
if (property.first != kSecondsKey) {
max_seconds = std::max(
max_seconds,
if (property.first != kOptimalSecondsKey) {
optimal_seconds = std::max(
optimal_seconds,
property.second /
GetProperty(property.first, per_second_rates_, INFINITY));
}
}
current_properties_[kSecondsKey] = max_seconds;
current_properties_[kOptimalSecondsKey] = optimal_seconds;
}
TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
@ -480,6 +481,25 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
return Status::OK();
}
Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
// Compute the cost of the true and false computations and take the maximum
// from those for each property.
TF_ASSIGN_OR_RETURN(const Properties true_computation_properties,
ProcessSubcomputation(conditional->true_computation()));
TF_ASSIGN_OR_RETURN(const Properties false_computation_properties,
ProcessSubcomputation(conditional->false_computation()));
current_properties_ = true_computation_properties;
for (const auto& property : false_computation_properties) {
if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_, property)) {
current_properties_[property.first] =
std::max(current_properties_[property.first], property.second);
}
}
current_should_compute_bottleneck_time_ = false;
return Status::OK();
}
Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
return Status::OK();
}
@ -496,8 +516,8 @@ float HloCostAnalysis::bytes_accessed() const {
return GetProperty(kBytesAccessedKey, properties_sum_);
}
float HloCostAnalysis::seconds() const {
return GetProperty(kSecondsKey, properties_sum_);
float HloCostAnalysis::optimal_seconds() const {
return GetProperty(kOptimalSecondsKey, properties_sum_);
}
int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
@ -512,8 +532,8 @@ int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
}
float HloCostAnalysis::seconds(const HloInstruction& hlo) const {
return GetPropertyForHlo(hlo, kSecondsKey, hlo_properties_);
float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
}
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(

View File
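The Postprocess change above renames the derived metric and makes the rule explicit: the optimal time is the bottleneck over all properties, i.e. the maximum over p of count_p / rate_p. A plain-C++ restatement under that assumption (names are illustrative, not XLA code):

#include <algorithm>
#include <map>
#include <string>

float OptimalSeconds(const std::map<std::string, float>& counts,
                     const std::map<std::string, float>& rates_per_second) {
  float optimal_seconds = 0.0f;
  for (const auto& property : counts) {
    auto rate = rates_per_second.find(property.first);
    // A missing or non-positive rate is treated as infinitely fast here,
    // contributing zero time, like the INFINITY default in the hunk above.
    if (rate == rates_per_second.end() || rate->second <= 0.0f) continue;
    optimal_seconds = std::max(optimal_seconds, property.second / rate->second);
  }
  return optimal_seconds;
}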

@ -42,7 +42,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
static constexpr char kFlopsKey[] = "flops";
static constexpr char kTranscendentalsKey[] = "transcendentals";
static constexpr char kBytesAccessedKey[] = "bytes accessed";
static constexpr char kSecondsKey[] = "seconds";
static constexpr char kOptimalSecondsKey[] = "optimal_seconds";
// shape_size is a function which returns the size in bytes of the top-level
// buffer of a shape.
@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleReshape(const HloInstruction* reshape) override;
Status HandleTranspose(const HloInstruction* transpose) override;
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
Status FinishVisit(const HloInstruction* root) override;
Status Preprocess(const HloInstruction* hlo) override;
@ -118,14 +119,14 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
float flop_count() const;
float transcendental_count() const;
float bytes_accessed() const;
float seconds() const;
float optimal_seconds() const;
// Returns the respective cost computed for a particular HLO instruction, or 0
// if the HLO was not found to have a cost in the analysis.
int64 flop_count(const HloInstruction& hlo) const;
int64 transcendental_count(const HloInstruction& hlo) const;
int64 bytes_accessed(const HloInstruction& hlo) const;
float seconds(const HloInstruction& hlo) const;
float optimal_seconds(const HloInstruction& hlo) const;
const Properties& properties() const { return properties_sum_; }
const float property(const string& key) const {

View File

@ -389,7 +389,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
static_assert(bytes_accessed == 64, "");
EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);
EXPECT_EQ(fusion_analysis.seconds(), 1 << i);
EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i);
}
}

View File

@ -37,6 +37,9 @@ namespace xla {
StatusOr<bool> HloDCE::Run(HloModule* module) {
bool changed = false;
VLOG(2) << "Before dce:";
XLA_VLOG_LINES(2, module->ToString());
for (auto* computation : module->MakeNonfusionComputations()) {
std::unordered_set<HloInstruction*> live_instructions;
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
@ -58,6 +61,8 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
}
for (HloInstruction* dead_root : dead_roots) {
VLOG(1) << "Removing dead root " << dead_root->ToString()
<< " and it's unused operands";
TF_RETURN_IF_ERROR(
computation->RemoveInstructionAndUnusedOperands(dead_root));
changed = true;
@ -87,6 +92,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
}
}
VLOG(2) << "After dce:";
XLA_VLOG_LINES(2, module->ToString());
return changed;
}

View File

@ -335,9 +335,31 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
template <typename NativeT,
typename std::enable_if<
std::is_integral<NativeT>::value &&
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleNot(HloInstruction* not_) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
return ~elem_operand;
}));
return Status::OK();
}
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleNot(HloInstruction* not_) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
return !elem_operand;
}));
return Status::OK();
}
template <typename NativeT,
typename std::enable_if<std::is_same<NativeT, bool>::value>::type* =
nullptr>
Status HandleNot(HloInstruction* not_) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
@ -357,7 +379,24 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return HandleNot<ReturnT>(not_);
}
Status HandleNegate(HloInstruction* negate) override {
template <typename NativeT,
typename std::enable_if<
std::is_signed<NativeT>::value &&
!std::is_floating_point<NativeT>::value>::type* = nullptr>
Status HandleNegate(HloInstruction* negate) {
using type = typename std::make_unsigned<NativeT>::type;
TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate],
ElementWiseUnaryOp(negate, [](ReturnT elem_operand) {
return NativeT(-type(elem_operand));
}));
return Status::OK();
}
template <typename NativeT,
typename std::enable_if<
!std::is_signed<NativeT>::value ||
std::is_floating_point<NativeT>::value>::type* = nullptr>
Status HandleNegate(HloInstruction* negate) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate],
ElementWiseUnaryOp(negate, [](ReturnT elem_operand) {
return -elem_operand;
@ -365,6 +404,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
Status HandleNegate(HloInstruction* negate) override {
return HandleNegate<ReturnT>(negate);
}
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
@ -402,7 +445,26 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
Status HandleMultiply(HloInstruction* multiply) override {
template <typename NativeT,
typename std::enable_if<
std::is_signed<NativeT>::value &&
!std::is_floating_point<NativeT>::value>::type* = nullptr>
Status HandleMultiply(HloInstruction* multiply) {
using type = typename std::make_unsigned<NativeT>::type;
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[multiply],
ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) {
return NativeT(type(lhs_elem) * type(rhs_elem));
}));
return Status::OK();
}
template <
typename NativeT,
typename std::enable_if<std::is_unsigned<NativeT>::value ||
std::is_floating_point<NativeT>::value ||
is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleMultiply(HloInstruction* multiply) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[multiply],
ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) {
@ -411,6 +473,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
Status HandleMultiply(HloInstruction* multiply) override {
return HandleMultiply<ReturnT>(multiply);
}
Status HandleSubtract(HloInstruction* subtract) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[subtract],
@ -516,9 +582,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return HandleRemainder<ReturnT>(remainder);
}
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
template <typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* =
nullptr>
Status HandleAnd(HloInstruction* and_) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[and_],
ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el & rhs_el;
}));
return Status::OK();
}
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleAnd(HloInstruction* and_) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[and_],
@ -539,9 +616,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return HandleAnd<ReturnT>(and_);
}
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
template <typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* =
nullptr>
Status HandleOr(HloInstruction* or_) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[or_],
ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el | rhs_el;
}));
return Status::OK();
}
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleOr(HloInstruction* or_) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[or_],
@ -645,7 +733,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleClamp(HloInstruction* clamp) {
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
[](ReturnT low, ReturnT high, ReturnT value) {
[](ReturnT low, ReturnT value, ReturnT high) {
return std::fmax(low, std::fmin(value, high));
};
TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
@ -814,13 +902,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
}
rhs_index[dnums.kernel_spatial_dimensions(ki)] =
rhs_spatial_index[ki];
window_dim.window_reversal()
? ((window_dim.size() - 1) - rhs_spatial_index[ki])
: rhs_spatial_index[ki];
}
result_val += lhs_literal.Get<ReturnT>(lhs_index) *
rhs_literal.Get<ReturnT>(rhs_index);
}
cnt:;
cnt : {}
} while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
return result_val;
@ -1287,6 +1377,50 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleSin(HloInstruction* sin) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin],
ElementWiseUnaryOp(sin, [](ReturnT elem_operand) {
return std::sin(elem_operand);
}));
return Status::OK();
}
template <
typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value ||
is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleSin(HloInstruction* sin) {
return InvalidArgument("Unsupported type for Sin");
}
Status HandleSin(HloInstruction* sin) override {
return HandleSin<ReturnT>(sin);
}
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleCos(HloInstruction* cos) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos],
ElementWiseUnaryOp(cos, [](ReturnT elem_operand) {
return std::cos(elem_operand);
}));
return Status::OK();
}
template <
typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value ||
is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleCos(HloInstruction* cos) {
return InvalidArgument("Unsupported type for Cos");
}
Status HandleCos(HloInstruction* cos) override {
return HandleCos<ReturnT>(cos);
}
private:
template <typename IndexT>
StatusOr<std::unique_ptr<Literal>> DynamicSlice(
@ -1397,8 +1531,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const auto* rhs = instruction->operand(1);
const auto* ehs = instruction->operand(2);
// TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
// removed.
// TODO(b/35950897, b/27796129): add DCHECK back once implicit
// broadcast is removed.
if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) &&
ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) &&
ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) {
@ -1565,6 +1699,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
}
std::vector<HloInstruction*> operands;
operands.reserve(owned_operands.size());
for (auto& operand : owned_operands) {
operands.push_back(operand.get());
}
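The evaluator changes above all follow one dispatch pattern: several enable_if-guarded member templates provide the per-type behaviour, and the virtual override simply forwards to Handle...<ReturnT>(). A self-contained illustration of that pattern in plain C++ (hypothetical example, not XLA code):

#include <cstdint>
#include <iostream>
#include <type_traits>

template <typename ReturnT>
struct Negator {
  // Signed integers: negate through the unsigned type so that negating the
  // minimum value wraps instead of overflowing, as in the hunk above.
  template <typename NativeT,
            typename std::enable_if<std::is_signed<NativeT>::value &&
                                    !std::is_floating_point<NativeT>::value>::type* = nullptr>
  ReturnT HandleNegate(ReturnT value) {
    using Unsigned = typename std::make_unsigned<NativeT>::type;
    return static_cast<ReturnT>(-static_cast<Unsigned>(value));
  }
  // Unsigned and floating-point types: plain negation is already well defined.
  template <typename NativeT,
            typename std::enable_if<!std::is_signed<NativeT>::value ||
                                    std::is_floating_point<NativeT>::value>::type* = nullptr>
  ReturnT HandleNegate(ReturnT value) {
    return -value;
  }
  // The non-template entry point forwards to the matching specialization.
  ReturnT Negate(ReturnT value) { return HandleNegate<ReturnT>(value); }
};

int main() {
  std::cout << Negator<int32_t>().Negate(5) << "\n";   // -5
  std::cout << Negator<double>().Negate(2.5) << "\n";  // -2.5
  return 0;
}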

View File

@ -46,20 +46,57 @@ class HloEvaluatorTest : public HloVerifiedTestBase {
HloEvaluatorTest() { evaluator_ = MakeUnique<HloEvaluator>(); }
std::unique_ptr<HloEvaluator> evaluator_;
void TestUnaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
std::unique_ptr<Literal> input, float aabs = 0) {
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
auto instruction = b.AddInstruction(
HloInstruction::CreateUnary(expected->shape(), opcode, c1));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
auto element_type = expected->shape().element_type();
if (element_type == F32 || element_type == F64) {
ErrorSpec error(aabs);
LiteralTestUtil::ExpectNear(*expected, *result, error);
} else {
LiteralTestUtil::ExpectEqual(*expected, *result);
}
}
void TestBinaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
std::unique_ptr<Literal> lhs,
std::unique_ptr<Literal> rhs) {
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
auto instruction = b.AddInstruction(
HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectEqual(*expected, *result);
}
};
// Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
// with 3 operands.
TEST_F(HloEvaluatorTest, DoesClamp) {
auto low = Literal::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
auto high = Literal::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
auto value = Literal::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
auto high = Literal::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
Shape shape = low->shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
auto instruction = b.AddInstruction(
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
@ -72,6 +109,28 @@ TEST_F(HloEvaluatorTest, DoesClamp) {
LiteralTestUtil::ExpectEqual(*expected, *result);
}
TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
auto low = Literal::CreateR0<float>(0.f);
auto value = Literal::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
auto high = Literal::CreateR0<float>(1.f);
Shape shape = value->shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
auto instruction = b.AddInstruction(
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
LiteralTestUtil::ExpectEqual(*expected, *result);
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
// with 3 operands.
TEST_F(HloEvaluatorTest, DoesSelect) {
@ -103,120 +162,101 @@ TEST_F(HloEvaluatorTest, DoesSelect) {
TEST_F(HloEvaluatorTest, DoesAdd) {
auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
auto instruction = b.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1, c2));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
auto expected = Literal::CreateR2<int64>({{3, 4}, {-96, 8}});
LiteralTestUtil::ExpectEqual(*expected, *result);
TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise and with 2 operands.
TEST_F(HloEvaluatorTest, DoesAnd) {
auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
auto expected = Literal::CreateR2<int64>({{0, 0}, {4, 4}});
TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise or with 2 operands.
TEST_F(HloEvaluatorTest, DoesOr) {
auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
auto expected = Literal::CreateR2<int64>({{3, 4}, {-100, 4}});
TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise multiply with 2 operands.
TEST_F(HloEvaluatorTest, DoesMultiply) {
auto lhs = Literal::CreateR2<int32>({{-1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int32>(
{{std::numeric_limits<int32>::min(), 4}, {4, 4}});
auto expected = Literal::CreateR2<int32>(
{{std::numeric_limits<int32>::min(), 0}, {-400, 16}});
TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise divide with 2 operands.
TEST_F(HloEvaluatorTest, DoesDivideInt64) {
auto lhs_s64 = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs_s64 = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2});
HloComputation::Builder b(TestName());
auto c1_s64 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_s64)));
auto c2_s64 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_s64)));
auto instruction = b.AddInstruction(HloInstruction::CreateBinary(
shape_s64, HloOpcode::kDivide, c1_s64, c2_s64));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
auto expected = Literal::CreateR2<int64>({{0, 0}, {-25, 1}});
LiteralTestUtil::ExpectEqual(*expected, *result);
TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
std::move(rhs));
}
TEST_F(HloEvaluatorTest, DoesDivideDouble) {
auto lhs_f64 = Literal::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
auto rhs_f64 = Literal::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2});
HloComputation::Builder b(TestName());
auto c1_f64 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_f64)));
auto c2_f64 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_f64)));
auto instruction = b.AddInstruction(HloInstruction::CreateBinary(
shape_f64, HloOpcode::kDivide, c1_f64, c2_f64));
module().AddEntryComputation(b.Build());
auto result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
auto lhs = Literal::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
auto rhs = Literal::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
auto expected =
Literal::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});
LiteralTestUtil::ExpectEqual(*expected, *result);
TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest, DoesAbsR2) {
auto operand = Literal::CreateR2<int64>({{1, -20}, {-100, 4}});
const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2});
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
auto instruction =
b.AddInstruction(HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
auto expected = Literal::CreateR2<int64>({{1, 20}, {100, 4}});
LiteralTestUtil::ExpectEqual(*expected, *result);
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesAbsR0) {
// For R0 literal.
const Shape& r0 = ShapeUtil::MakeShape(F32, {});
auto operand = Literal::CreateR0<float>(-1.0f);
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
auto instruction =
b.AddInstruction(HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1));
module().AddEntryComputation(b.Build());
auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie();
auto expected = Literal::CreateR0<float>(1.0f);
LiteralTestUtil::ExpectEqual(*expected, *result);
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesAbsR1WithZeroSize) {
// For R1 literal with dimension of size 0.
Shape empty_r1 = ShapeUtil::MakeShape(F32, {0});
auto operand = Literal::CreateR1<float>({});
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
auto instruction = b.AddInstruction(
HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1));
module().AddEntryComputation(b.Build());
auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie();
auto expected = Literal::CreateR1<float>({});
LiteralTestUtil::ExpectEqual(*expected, *result);
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesNegateR2) {
auto operand = Literal::CreateR2<int32>(
{{0, std::numeric_limits<int32>::min()}, {-1, 4}});
auto expected =
Literal::CreateR2<int32>({{0, std::numeric_limits<int>::min()}, {1, -4}});
TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesCosR2) {
auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
auto expected = Literal::CreateR2<float>({{1, -1}, {-1, 1}});
TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesSinR2) {
auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
auto expected = Literal::CreateR2<float>({{0, 0}, {0, 0}});
TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand),
0x1.0P-20);
}
TEST_F(HloEvaluatorTest, DoesNotR2) {
auto operand =
Literal::CreateR2<int32>({{0, std::numeric_limits<int>::min()},
{-1, std::numeric_limits<int>::max()}});
auto expected =
Literal::CreateR2<int32>({{-1, std::numeric_limits<int>::max()},
{0, std::numeric_limits<int>::min()}});
TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand));
}
// Verifies that HloEvaluator evaluates an HLO Computation whose operands are
// neither parameters nor constants.
TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
@ -794,6 +834,83 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
LiteralTestUtil::ExpectEqual(*expected, *result);
}
TEST_F(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
HloComputation::Builder b(TestName());
// clang-format off
// Input dimensions: [feature=2, height=3, batch=1, width=4]
Array4D<float> input({
{{{1, 2, 3, 4}},
{{5, 6, 7, 8}},
{{9, 10, 11, 12}}},
{{{13, 14, 15, 16}},
{{17, 18, 19, 20}},
{{21, 22, 23, 24}}}
});
// Weight dimensions:
// [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
Array4D<float> weight({{
{{1, 7, 13},
{4, 10, 16}},
{{2, 8, 14},
{5, 11, 17}},
{{3, 9, 15},
{6, 12, 18}}
}});
// clang-format on
auto lhs_literal = Literal::CreateR4FromArray4D<float>(input);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
auto rhs_literal = Literal::CreateR4FromArray4D<float>(weight);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse(
rhs_instruction->shape(), rhs_instruction, {3, 1}));
Window window;
WindowDimension dim;
dim.set_size(3);
dim.set_stride(1);
dim.set_padding_low(0);
dim.set_padding_high(0);
dim.set_window_dilation(1);
dim.set_base_dilation(1);
dim.set_window_reversal(true);
*window.add_dimensions() = dim;
*window.add_dimensions() = dim;
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(2);
dnums.set_output_batch_dimension(2);
dnums.set_input_feature_dimension(0);
dnums.set_output_feature_dimension(0);
dnums.add_spatial_dimensions(1);
dnums.add_spatial_dimensions(3);
dnums.set_kernel_output_feature_dimension(0);
dnums.set_kernel_input_feature_dimension(2);
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(1);
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
shape, lhs_instruction, rhs_instruction, window, dnums));
auto computation = module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
Array4D<float> expected_array({{{{2514, 2685}}}});
// clang-format on
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result);
}
TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) {
HloComputation::Builder b(TestName());

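The reversed-window test above exercises the indexing rule added to the evaluator earlier in this diff: with window_reversal set, kernel element k of a window of the given size reads offset (size - 1) - k instead of k. A tiny stand-alone check of that rule (plain C++, illustrative values):

#include <cstdio>

int main() {
  const int window_size = 3;
  for (int k = 0; k < window_size; ++k) {
    const int forward = k;
    const int reversed = (window_size - 1) - k;
    std::printf("k=%d forward=%d reversed=%d\n", k, forward, reversed);
  }
  return 0;
}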
View File

@ -83,7 +83,7 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter(
instruction_info->transcendental_count =
cost_analysis.transcendental_count(*hlo);
instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo);
instruction_info->seconds = cost_analysis.seconds(*hlo);
instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo);
instruction_info->profile_index =
hlo_profile_index_map.GetProfileIndexFor(*hlo);
CHECK_LT(instruction_info->profile_index, max_profile_index);

View File

@ -90,10 +90,10 @@ TEST_F(HloExecutionProfileTest, Basic) {
const std::vector<string>& line_3 = lines_and_words[3];
EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles));
EXPECT_EQ(line_2[kInstructionNameIndex], dot_instruction->name());
EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name());
EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles));
EXPECT_EQ(line_3[kInstructionNameIndex], add_instruction->name());
EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name());
}
} // namespace
} // namespace xla

View File

@ -864,9 +864,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
// (eg, parameter).
switch (instr->opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kAdd:
case HloOpcode::kAnd:
case HloOpcode::kAtan2:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kComplex:
@ -882,18 +883,19 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kRng:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
@ -903,7 +905,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kSort:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kRng:
// De-emphasize scalar-shaped elementwise ops -- they're generally
// uninteresting.
if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
@ -911,9 +912,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
}
return kYellow;
case HloOpcode::kBitcast:
case HloOpcode::kTuple:
case HloOpcode::kTrace:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTrace:
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
// De-emphasize nodes which broadcast a scalar within a fusion node --
@ -952,28 +953,28 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kRed;
case HloOpcode::kParameter:
return kParameterColor;
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kReduce:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
return kPurple;
case HloOpcode::kMap:
case HloOpcode::kFusion:
case HloOpcode::kMap:
return kGray;
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
return kBrown;
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
case HloOpcode::kWhile:
case HloOpcode::kCall:
return kDarkGreen;
case HloOpcode::kConstant:
LOG(FATAL) << "Constants don't get their own nodes in the graph.";

View File

@ -52,7 +52,9 @@ using ::tensorflow::strings::StrCat;
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map) {
const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
const std::function<void(std::unique_ptr<HloComputation>)>&
add_fused_computation) {
TF_RET_CHECK(!proto.opcode().empty());
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
@ -78,19 +80,19 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(!proto.fusion_kind().empty());
TF_ASSIGN_OR_RETURN(instruction->fusion_kind_,
StringToFusionKind(proto.fusion_kind()));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> fused_computation,
HloComputation::CreateFromProto(
module, proto.fused_instructions_computation(), computation_map,
/*fusion_instruction=*/instruction.get()));
instruction->called_computations_.push_back(
module->AddEmbeddedComputation(std::move(fused_computation)));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> fused_computation,
HloComputation::CreateFromProto(
module, proto.fused_instructions_computation(),
computation_map, add_fused_computation,
/*fusion_instruction=*/instruction.get()));
instruction->called_computations_.push_back(fused_computation.get());
add_fused_computation(std::move(fused_computation));
} else {
for (const string& computation_name : proto.called_computation_names()) {
TF_RET_CHECK(ContainsKey(*computation_map, computation_name))
TF_RET_CHECK(ContainsKey(computation_map, computation_name))
<< "No computation named " << computation_name;
instruction->called_computations_.push_back(
computation_map->at(computation_name));
computation_map.at(computation_name));
}
}
@ -149,7 +151,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
WrapUnique(new HloInstruction(HloOpcode::kParameter, shape));
instruction->parameter_number_ = parameter_number;
instruction->parameter_name_ = name;
instruction->name_ = "%" + name;
instruction->name_ = name;
return instruction;
}
@ -436,6 +438,23 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
const Shape& shape, HloInstruction* pred,
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation) {
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
instruction->AppendOperand(pred);
instruction->AppendOperand(true_computation_arg);
instruction->AppendOperand(false_computation_arg);
// In called_computations_, the index of true_computation must be 0 and that
// of false computation must be 1, as defined by kTrueComputationIndex and
// kFalseComputationIndex.
instruction->called_computations_.push_back(true_computation);
instruction->called_computations_.push_back(false_computation);
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
@ -499,6 +518,15 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
return instruction;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBitcastConvert(const Shape& shape,
HloInstruction* operand) {
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
@ -631,7 +659,10 @@ HloInstruction::CreateSelectAndScatter(
CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
CHECK(std::equal(operand->shape().dimensions().begin(),
operand->shape().dimensions().end(),
Permute(dimensions, shape.dimensions()).begin()));
Permute(dimensions, shape.dimensions()).begin()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< ", operand->shape(): " << ShapeUtil::HumanString(shape)
<< ", dimensions: {" << Join(dimensions, ", ") << "}";
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape));
instruction->AppendOperand(operand);
@ -791,7 +822,7 @@ HloInstruction* HloInstruction::FuseInstructionInternal(
HloInstruction* HloInstruction::CloneAndFuseInternal(
HloInstruction* instruction_to_fuse, bool add_output) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(instruction_to_fuse->IsFusable());
CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString();
VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
HloInstruction* clone = nullptr;
if (called_computations_.empty()) {
@ -869,10 +900,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// parameter instruction.
int64 param_no = fused_parameters.size();
// Name the parameter after the instruction it represents in the outer
// (non-fusion) computation. Strip the leading "%" from the operand name
// to avoid a double %%.
string param_name =
StrCat(operand->name().substr(1), ".param_", param_no);
// (non-fusion) computation.
string param_name = StrCat(operand->name(), ".param_", param_no);
fused_param = fused_instructions_computation()->AddParameter(
CreateParameter(param_no, operand->shape(), param_name));
AppendOperand(operand);
@ -1013,7 +1042,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
for (const HloInstruction* new_operand : new_operands) {
VLOG(3) << " " << new_operand->name();
VLOG(3) << " %" << new_operand->name();
}
std::unique_ptr<HloInstruction> clone;
@ -1095,6 +1124,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateConvert(shape, new_operands[0]);
break;
case HloOpcode::kBitcastConvert:
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
case HloOpcode::kReducePrecision:
CHECK_EQ(new_operands.size(), 1);
clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_,
@ -1535,6 +1568,7 @@ bool HloInstruction::IdenticalSlowPath(
// A convert result is determined by the primitive type that the operand is
// converted into.
case HloOpcode::kConvert:
case HloOpcode::kBitcastConvert:
return shape().element_type() == other.shape().element_type();
// A reduce-precision operation is determined by the bit sizes.
@ -1814,6 +1848,32 @@ void HloInstruction::set_scatter(HloComputation* computation) {
called_computations_[kScatterComputationIndex] = computation;
}
HloComputation* HloInstruction::true_computation() const {
CHECK_EQ(HloOpcode::kConditional, opcode_);
return called_computations_[kTrueComputationIndex];
}
HloComputation* HloInstruction::false_computation() const {
CHECK_EQ(HloOpcode::kConditional, opcode_);
return called_computations_[kFalseComputationIndex];
}
void HloInstruction::set_true_computation(HloComputation* true_computation) {
// Don't allow changing the computation for fused instructions so we don't
// have to recompute called_instructions for the entire fusion instruction.
CHECK(!IsFused());
CHECK_EQ(HloOpcode::kConditional, opcode_);
called_computations_[kTrueComputationIndex] = true_computation;
}
void HloInstruction::set_false_computation(HloComputation* false_computation) {
// Don't allow changing the computation for fused instructions so we don't
// have to recompute called_instructions for the entire fusion instruction.
CHECK(!IsFused());
CHECK_EQ(HloOpcode::kConditional, opcode_);
called_computations_[kFalseComputationIndex] = false_computation;
}
string HloInstruction::SignatureString() const {
string operands =
Join(operands_, ", ", [](string* out, HloInstruction* operand) {
@ -1825,7 +1885,7 @@ string HloInstruction::SignatureString() const {
string HloInstruction::ToString(bool compact_operands, bool include_metadata,
bool include_large_constants) const {
string result =
StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
StrCat("%", name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
HloOpcodeString(opcode()), "(",
OperandsToString(compact_operands, include_large_constants), ")");
for (const string& extra : ExtraAttributesToString()) {
@ -1877,7 +1937,7 @@ string HloInstruction::OperandsToString(bool compact,
operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
*out += ShapeUtil::HumanStringWithLayout(operand->shape());
if (!compact) {
StrAppend(out, " ", operand->name());
StrAppend(out, " %", operand->name());
}
});
const int64 remaining = operands_.size() - slice.size();
@ -1896,7 +1956,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
if (CanHaveDimensionsField()) {
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
}
if (window_ != nullptr) {
if (window_ != nullptr && window_->dimensions_size() != 0) {
extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
}
if (padding_config_ != nullptr) {
@ -1964,7 +2024,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
extra.push_back(StrCat("control-predecessors={",
Join(control_predecessors_, ", ",
[](string* out, HloInstruction* pre) {
StrAppend(out, pre->name());
StrAppend(out, "%", pre->name());
}),
"}"));
}
@ -1979,10 +2039,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
}
string HloInstruction::ToShortString() const {
return StrCat(name(), " = ", HloOpcodeString(opcode()), "(",
return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
Join(operands_, ", ",
[](string* out, HloInstruction* operand) {
StrAppend(out, operand->name());
StrAppend(out, "%", operand->name());
}),
")");
}
@ -2076,8 +2136,10 @@ string HloInstruction::ToCategory() const {
bool saw_rank_1 = false;
bool saw_higher_rank = false;
for (const auto* operand : operands()) {
saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1;
saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1;
if (!ShapeUtil::IsTuple(operand->shape())) {
saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1;
saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1;
}
}
if (saw_rank_1 && saw_higher_rank) {
return "rank-1-broadcast binary fusion";
@ -2130,25 +2192,13 @@ bool HloInstruction::IsFusable() const {
if (tracing()) {
return false;
}
// Some kinds of instructions don't make sense to fuse.
switch (opcode_) {
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kTrace:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
return false;
// Only fuse Rng if it is used once, otherwise the random numbers generated
// will be different in each fusion. If it is the root (user count = 0)
// then it is the equivalent of having one user.
case HloOpcode::kRng:
return users_.size() <= 1;
// Side-effecting instructions cannot be fused.
default:
return true;
return !HasSideEffect();
}
}
@ -2199,7 +2249,7 @@ HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
: unique_id_(-1),
opcode_(opcode),
shape_(shape),
name_("%" + HloOpcodeString(opcode)) {
name_(HloOpcodeString(opcode)) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
@ -2259,6 +2309,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleConcatenate(this);
case HloOpcode::kConvert:
return visitor->HandleConvert(this);
case HloOpcode::kBitcastConvert:
return visitor->HandleBitcastConvert(this);
case HloOpcode::kCopy:
return visitor->HandleCopy(this);
case HloOpcode::kMultiply:
@ -2345,6 +2397,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleFusion(this);
case HloOpcode::kCall:
return visitor->HandleCall(this);
case HloOpcode::kConditional:
return visitor->HandleConditional(this);
case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this);
case HloOpcode::kRecv:
@ -2357,7 +2411,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSendDone(this);
// These opcodes are not handled here.
case HloOpcode::kConditional:
case HloOpcode::kTrace:
break;
}
@ -2423,7 +2476,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
visitor->GetVisitState(current_id);
if (visit_state == Visitor::kVisited) {
dfs_stack.pop_back();
VLOG(3) << "Not visiting HLO " << current_node->name()
VLOG(3) << "Not visiting HLO %" << current_node->name()
<< " as it was already visited.";
continue;
}
@ -2432,7 +2485,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
dfs_stack.pop_back();
TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
VLOG(2) << "Visiting HLO " << current_node->name();
VLOG(2) << "Visiting HLO %" << current_node->name();
TF_RETURN_IF_ERROR(current_node->Visit(visitor));
visitor->SetVisitState(current_id, Visitor::kVisited);
TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
@ -2477,7 +2530,7 @@ template <typename HloInstructionPtr>
Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
bool call_finish_visit,
bool ignore_control_predecessors) {
VLOG(3) << "HloInstruction::Accept(" << name() << ")";
VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
TF_RETURN_IF_ERROR(
PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
if (call_finish_visit) {
@ -2493,7 +2546,7 @@ template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
Status HloInstruction::AcceptWithOperandOrder(
DfsHloVisitor* visitor, const CompareFunction& operand_order,
bool call_finish_visit) {
VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")";
VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
InternalCompareFunction func = [&operand_order](
std::pair<int, const HloInstruction*> a,
std::pair<int, const HloInstruction*> b) {
@ -2556,7 +2609,7 @@ Status HloInstruction::Accept(
Status HloInstruction::AcceptOrdered(
DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order) {
VLOG(2) << "HloInstruction::AcceptOrdered(" << name() << ")";
VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")";
TF_RET_CHECK(OrderIsTopologicalSort(order));
// Compute the predecessors of this instruction.
@ -2575,7 +2628,7 @@ Status HloInstruction::AcceptOrdered(
// The visitor can mark instructions as visited to skip particular
// instructions.
if (visitor->DidVisit(*const_instruction)) {
VLOG(3) << "Not visiting HLO " << const_instruction->name()
VLOG(3) << "Not visiting HLO %" << const_instruction->name()
<< " as it was already visited.";
continue;
}
@ -2584,7 +2637,7 @@ Status HloInstruction::AcceptOrdered(
const_cast<HloInstruction*>(const_instruction);
TF_RETURN_IF_ERROR(visitor->Preprocess(instruction));
VLOG(2) << "Visiting HLO " << instruction->name();
VLOG(2) << "Visiting HLO %" << instruction->name();
TF_RETURN_IF_ERROR(instruction->Visit(visitor));
visitor->SetVisited(*instruction);
TF_RETURN_IF_ERROR(visitor->Postprocess(instruction));
@ -2630,6 +2683,7 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kConvert:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:

View File

@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
@ -83,12 +84,16 @@ class HloInstruction {
// must contain all operands of the newly constructed instruction.
// computation_map: a map from computation name to HloComputation*. This map
// must contain all computations which the newly constructed instruction
// calls. If the instruction is a fusion instruction, then the fusion
// computation is added to this map and the module.
// calls.
// add_fused_computation: A function to call to add a fused
// computation. Used (clearly) when the instruction is a fusion
// instruction.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map);
const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
const std::function<void(std::unique_ptr<HloComputation>)>&
add_fused_computation);
// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
@ -171,6 +176,11 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
HloInstruction* operand);
// Creates a bitcast conversion instruction, where operand is the data to
// convert and shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateBitcastConvert(
const Shape& shape, HloInstruction* operand);
// Creates an infeed instruction, which reads data of the given shape from the
// Infeed interface of the device.
static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
@ -305,6 +315,11 @@ class HloInstruction {
HloComputation* body,
HloInstruction* init);
static std::unique_ptr<HloInstruction> CreateConditional(
const Shape& shape, HloInstruction* pred,
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation);
// Creates a fusion instruction. A fusion instruction contains one or more
// fused instructions forming an expression with a single root
// "fused_root". Additional instructions can be added to the fusion
@ -608,6 +623,15 @@ class HloInstruction {
void set_select(HloComputation* select);
void set_scatter(HloComputation* scatter);
// Gets/sets the true and false HloComputation for Conditional. The setters
// should only be called by HloModule or HloComputation methods.
//
// Precondition: The instruction is a Conditional instruction.
HloComputation* true_computation() const;
HloComputation* false_computation() const;
void set_true_computation(HloComputation* true_computation);
void set_false_computation(HloComputation* false_computation);
// Returns a string for the signature of this instruction if considered as a
// function, e.g. the signature of an F32 add is (F32, F32) -> F32.
string SignatureString() const;
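
A minimal sketch of how the new kConditional factory might be wired up from a
computation builder (the shape, predicate, argument instructions, and the two
computations below are assumed to exist and are purely illustrative):

  // Hypothetical builder-side usage; every operand named here is an assumption.
  HloInstruction* conditional = builder.AddInstruction(
      HloInstruction::CreateConditional(
          result_shape, pred,
          /*true_computation_arg=*/true_operand, /*true_computation=*/on_true,
          /*false_computation_arg=*/false_operand, /*false_computation=*/on_false));
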
@ -1192,6 +1216,10 @@ class HloInstruction {
// kSelectAndScatter computations.
kSelectComputationIndex = 0,
kScatterComputationIndex = 1,
// kConditional computations.
kTrueComputationIndex = 0,
kFalseComputationIndex = 1,
};
// Outfeed configuration information, only present for kOutfeed.

View File

@ -1138,35 +1138,34 @@ TEST_F(HloInstructionTest, CloneSuffixNames) {
// Test cloning the same instruction multiple times.
auto foo =
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo");
EXPECT_EQ(foo->Clone()->name(), "%foo.clone");
EXPECT_EQ(foo->Clone()->Clone()->name(), "%foo.clone2");
EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "%foo.clone3");
EXPECT_EQ(foo->Clone()->name(), "foo.clone");
EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2");
EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3");
// Test custom suffixes.
EXPECT_EQ(foo->Clone("bar")->name(), "%foo.bar");
EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "%foo.bar2");
EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(),
"%foo.bar2.clone");
EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar");
EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2");
EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone");
// Test instruction name with a dot.
auto foo_baz = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "foo.baz");
EXPECT_EQ(foo_baz->Clone()->name(), "%foo.baz.clone");
EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone");
// Test incrementing a large number after the suffix.
auto foo_clone234 = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "foo.clone234");
EXPECT_EQ(foo_clone234->Clone()->name(), "%foo.clone235");
EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235");
// Test a non-numeric string after the cloning suffix.
auto foo_clonexyz = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz");
EXPECT_EQ(foo_clonexyz->Clone()->name(), "%foo.clonexyz.clone");
EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone");
// Test a name with multiple appearances of the suffix.
auto foo_clone_clone3 = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3");
EXPECT_EQ(foo_clone_clone3->Clone()->name(), "%foo.clone.clone4");
EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4");
}
TEST_F(HloInstructionTest, Stringification) {

View File

@ -87,6 +87,7 @@ HLO_MATCHER(Call);
HLO_MATCHER(Ceil);
HLO_MATCHER(Clamp);
HLO_MATCHER(Concatenate);
HLO_MATCHER(Conditional);
HLO_MATCHER(Constant);
HLO_MATCHER(Convert);
HLO_MATCHER(Convolution);

View File

@ -290,9 +290,16 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
for (const HloComputationProto& computation_proto : proto.computations()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
HloComputation::CreateFromProto(
module.get(), computation_proto, &computation_map));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> computation,
HloComputation::CreateFromProto(
module.get(), computation_proto, computation_map,
/*add_fused_computation=*/
[&module](std::unique_ptr<HloComputation> fused_computation) {
module->AddComputationInternal(std::move(fused_computation),
/*is_entry=*/false,
/*uniquify_names=*/false);
}));
CHECK_NE(computation.get(), nullptr);
TF_RET_CHECK(!ContainsKey(computation_map, computation->name()));
string computation_name = computation->name();

View File

@ -52,6 +52,7 @@ namespace xla {
V(kBatchNormInference, "batch-norm-inference") \
V(kBatchNormTraining, "batch-norm-training") \
V(kBitcast, "bitcast") \
V(kBitcastConvert, "bitcast-convert") \
V(kBroadcast, "broadcast") \
V(kCall, "call", kHloOpcodeIsVariadic) \
V(kCeil, "ceil") \

View File

@ -50,7 +50,7 @@ string HloProfilePrinter::ToString(const int64* counters,
/*short_name=*/instruction->short_name, instruction->category,
counters[instruction->profile_index], instruction->flop_count,
instruction->transcendental_count, instruction->bytes_accessed,
instruction->seconds);
instruction->optimal_seconds);
}
result += builder.ToString();

View File

@ -41,7 +41,7 @@ class HloProfilePrinter {
float flop_count;
float transcendental_count;
float bytes_accessed;
float seconds;
float optimal_seconds;
// The index into the profile counters array for the HloInstruction
// corresponding to this HloInstructionInfo.

View File

@ -62,18 +62,11 @@ bool IsRematerializable(const HloInstruction* instruction) {
case HloOpcode::kConstant:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kCustomCall:
case HloOpcode::kOutfeed:
case HloOpcode::kInfeed:
case HloOpcode::kParameter:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTrace:
case HloOpcode::kWhile:
return false;
default:
return true;
return !instruction->HasSideEffect();
}
}

View File

@ -114,11 +114,16 @@ HloRunner::~HloRunner() {
StatusOr<se::DeviceMemoryBase> HloRunner::Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
Shape* result_shape) {
Shape* result_shape, bool run_hlo_passes) {
if (run_hlo_passes) {
TF_ASSIGN_OR_RETURN(
module, backend().compiler()->RunHloPasses(
std::move(module), backend().default_stream_executor()));
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend().compiler()->Compile(std::move(module),
backend().default_stream_executor()));
backend().compiler()->RunBackend(std::move(module),
backend().default_stream_executor()));
se::Stream stream(backend().default_stream_executor());
stream.Init();
@ -193,10 +198,12 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferFromDevice(
StatusOr<std::unique_ptr<Literal>> HloRunner::ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
bool run_hlo_passes) {
Shape result_shape;
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase device_base,
Execute(std::move(module), arguments, &result_shape));
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase device_base,
Execute(std::move(module), arguments, &result_shape, run_hlo_passes));
return TransferFromDevice(result_shape, device_base);
}

View File

@ -65,17 +65,20 @@ class HloRunner {
// Executes the given module with given literals as input and returns the
// result as a Literal. The LiteralPtr type accepts Literal* or
// std::unique_ptr<Literal>.
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<LiteralPtr> literals);
const tensorflow::gtl::ArraySlice<LiteralPtr> literals,
bool run_hlo_passes = true);
// Executes the given module and returns a global data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Shape* result_shape);
Shape* result_shape, bool run_hlo_passes = true);
// Transfers the given literal to the device and returns the data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> TransferToDevice(
@ -90,7 +93,8 @@ class HloRunner {
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments);
arguments,
bool run_hlo_passes = true);
// If backend is not created in the constructor, creates and returns the
// default backend. If creation fails, crashes the program.
@ -112,14 +116,15 @@ class HloRunner {
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<LiteralPtr> literals) {
const tensorflow::gtl::ArraySlice<LiteralPtr> literals,
bool run_hlo_passes) {
std::vector<perftools::gputools::DeviceMemoryBase> arguments;
for (const auto& literal : literals) {
TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument,
TransferToDevice(*literal));
arguments.push_back(argument);
}
return ExecuteAndTransfer(std::move(module), arguments);
return ExecuteAndTransfer(std::move(module), arguments, run_hlo_passes);
}
} // namespace xla
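
As a rough usage sketch of the new flag (the module-building helper and the
argument literals below are assumed, not part of this change): passing
run_hlo_passes=false skips the optimization pipeline and hands the module
straight to RunBackend, while the default of true optimizes first.

  // Hypothetical caller; BuildTestModule() and the argument literals are
  // placeholders for whatever the test constructs.
  HloRunner runner;
  std::unique_ptr<HloModule> module = BuildTestModule();
  std::vector<Literal*> args = {&arg0, &arg1};
  StatusOr<std::unique_ptr<Literal>> result =
      runner.Execute<Literal*>(std::move(module), args, /*run_hlo_passes=*/false);
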

View File

@ -160,7 +160,59 @@ bool HloSharding::HasUniqueDevice() const {
}
}
Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
if (!ShapeUtil::IsTuple(shape)) {
return tensorflow::errors::InvalidArgument(
StrCat("Sharding is tuple-shaped but validation shape is not."));
}
// The easiest way to get the number of elements in a nested tuple is just to
// create a shape tree. We could call GetAsShapeTree, but that will try and
// apply our tuple_shardings_ to the shape tree, and that might cause a crash
// at this point as we haven't validated them.
ShapeTree<bool> bool_shape_tree(shape, false);
int64 num_leaves =
std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end());
if (num_leaves != tuple_elements_.size()) {
return tensorflow::errors::InvalidArgument(
StrCat("Validation tuple shape has ", num_leaves,
" leaf elements, but this sharding contains ",
tuple_elements_.size(), " elements."));
}
// Now we've validated the number of tuple elements, it's safe to request a
// shape tree.
ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
for (const auto& index_to_sharding : shape_tree.leaves()) {
Status status = index_to_sharding.second.ValidateNonTuple(
ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
if (!status.ok()) {
tensorflow::errors::AppendToMessage(
&status, StrCat("Note: While validating sharding tuple element ",
index_to_sharding.first.ToString(), " which is ",
index_to_sharding.second.ToString()));
return status;
}
}
return Status::OK();
}
Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
Status status = IsTuple() ? ValidateTuple(shape, num_devices)
: ValidateNonTuple(shape, num_devices);
if (!status.ok()) {
tensorflow::errors::AppendToMessage(
&status, StrCat("Note: While validating sharding ", ToString(),
" against shape ", ShapeUtil::HumanString(shape)));
}
return status;
}
Status HloSharding::ValidateNonTuple(const Shape& shape,
int64 num_devices) const {
if (ShapeUtil::IsTuple(shape)) {
return tensorflow::errors::InvalidArgument(
StrCat("Validation shape is a tuple but sharding is not."));
}
if (replicated_) {
return Status::OK();
}
@ -174,13 +226,11 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
// Don't overwrite a bad status, so we report the first error.
if (status.ok()) {
if (core >= num_devices) {
status =
tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
"core ", core, " > ", num_devices, " in tile assignment"));
status = tensorflow::errors::InvalidArgument(StrCat(
"core ", core, " > ", num_devices, " in tile assignment"));
} else if (seen_cores.count(core) != 0) {
status =
tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
"core ", core, " is not unique in tile assignment"));
status = tensorflow::errors::InvalidArgument(
StrCat("core ", core, " is not unique in tile assignment"));
}
}
seen_cores.insert(core);
@ -214,9 +264,9 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
auto tile_dim = tile_shape_.dimensions(i);
auto shape_dim = shape.dimensions(i);
if (tile_dim > shape_dim) {
return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
"Tile is larger than input shape (dimension ", i, ", ", tile_dim,
" > ", shape_dim));
return tensorflow::errors::InvalidArgument(
StrCat("Tile is larger than input shape (dimension ", i, ", ",
tile_dim, " > ", shape_dim));
}
}
@ -226,10 +276,10 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
int64 expected_dim =
CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i));
if (tile_assignment_.dimensions()[i] != expected_dim) {
return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
"Tile assignment tensor has incorrect shape. Dimension ", i,
" expected ", expected_dim, " but got ",
tile_assignment_.dimensions()[i]));
return tensorflow::errors::InvalidArgument(
StrCat("Tile assignment tensor has incorrect shape. Dimension ", i,
" expected ", expected_dim, " but got ",
tile_assignment_.dimensions()[i]));
}
}
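
For orientation, a hedged sketch of exercising the new tuple-aware validation
(the proto, shapes, and device count below are invented for illustration):

  // Two tuple shardings for a tuple shape with two leaves.
  OpSharding proto;
  proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
  *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
  *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto();
  HloSharding sharding = HloSharding::FromProto(proto).ConsumeValueOrDie();
  Shape shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {2, 2})});
  // Leaf count (2) matches the number of tuple shardings and device 1 < 4,
  // so this succeeds; a mismatched leaf count yields InvalidArgument.
  Status s = sharding.Validate(shape, /*num_devices=*/4);
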

View File

@ -222,6 +222,11 @@ class HloSharding {
tile_assignment_({0}),
tuple_elements_(tuple_shardings) {}
// Internal helper to validate a tuple sharding.
Status ValidateTuple(const Shape& shape, int64 num_devices) const;
// Internal helper to validate a non-tuple (leaf) sharding.
Status ValidateNonTuple(const Shape& shape, int64 num_devices) const;
bool replicated_;
bool maximal_;
bool tuple_;

View File

@ -145,11 +145,13 @@ TEST_F(HloShardingTest, NestedTuple) {
ShapeUtil::MakeShape(F32, {4, 6}),
});
HloSharding tiled_sharding = HloSharding::Tile(
ShapeUtil::MakeShape(F32, {4, 3}), Array<int64>({{0, 1}}));
OpSharding proto;
proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
*proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
*proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto();
*proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto();
*proto.add_tuple_shardings() = tiled_sharding.ToProto();
HloSharding tuple_sharding =
HloSharding::FromProto(proto).ConsumeValueOrDie();
@ -157,7 +159,15 @@ TEST_F(HloShardingTest, NestedTuple) {
tuple_sharding.GetAsShapeTree(nested_tuple_shape);
EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1));
EXPECT_EQ(shape_tree.element({2}), tiled_sharding);
EXPECT_IS_OK(tuple_sharding.Validate(nested_tuple_shape, /*num_devices=*/5));
// Test should fail because tuple element count does not match.
EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeTupleShape({}),
/*num_devices=*/5));
// Test should fail because the input type is not a tuple.
EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeShape(F32, {}),
/*num_devices=*/5));
}
TEST_F(HloShardingTest, Hash) {

View File

@ -59,15 +59,17 @@ class ShapeVerifier : public DfsHloVisitor {
}
Status HandleConvert(HloInstruction* convert) override {
if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) {
TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape()))
<< "Unsupported complex->real kConvert";
}
return CheckShape(convert, ShapeInference::InferConvertShape(
convert->operand(0)->shape(),
convert->shape().element_type()));
}
Status HandleBitcastConvert(HloInstruction* convert) override {
return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
convert->operand(0)->shape(),
convert->shape().element_type()));
}
Status HandleCopy(HloInstruction* copy) override {
return CheckUnaryShape(copy);
}
@ -263,6 +265,15 @@ class ShapeVerifier : public DfsHloVisitor {
xla_while->while_body()->ComputeProgramShape().result());
}
Status HandleConditional(HloInstruction* conditional) override {
TF_RETURN_IF_ERROR(CheckShape(
conditional,
conditional->true_computation()->ComputeProgramShape().result()));
return CheckShape(
conditional,
conditional->false_computation()->ComputeProgramShape().result());
}
Status HandlePad(HloInstruction* pad) override {
return CheckShape(pad,
ShapeInference::InferPadShape(pad->operand(0)->shape(),
@ -571,7 +582,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
// or ComputationLowerer::Visit()
TF_RET_CHECK(instruction->dimensions().size() ==
ShapeUtil::Rank(instruction->operand(0)->shape()))
<< "Broadcast HLO has invalid number of dimensions.";
<< "Broadcast HLO has invalid number of dimensions.";
} else if (instruction->opcode() == HloOpcode::kWhile) {
auto* while_cond = instruction->while_condition();
auto* while_body = instruction->while_body();

View File

@ -33,7 +33,9 @@ namespace xla {
switch (instruction.opcode()) {
// Cheap instructions.
case HloOpcode::kAdd:
case HloOpcode::kAnd:
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kBroadcast:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
@ -53,15 +55,14 @@ namespace xla {
case HloOpcode::kInfeed:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kOutfeed:
case HloOpcode::kPad:
case HloOpcode::kReal:
@ -88,9 +89,9 @@ namespace xla {
// Expensive instructions.
case HloOpcode::kAtan2:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConvolution:
@ -104,19 +105,19 @@ namespace xla {
case HloOpcode::kMap:
case HloOpcode::kParameter:
case HloOpcode::kPower:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kRemainder:
case HloOpcode::kRng:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kSort:
case HloOpcode::kTanh:
case HloOpcode::kTrace:
case HloOpcode::kWhile:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
return true;
}

View File

@ -69,11 +69,19 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
return pipeline.Run(hlo_module).status();
}
StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::Compile(
StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
std::unique_ptr<HloModule> hlo_module,
se::StreamExecutor* /*stream_exec*/) {
VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
return std::move(hlo_module);
}
StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec) {
TF_RET_CHECK(stream_exec != nullptr);
VLOG(1) << "Generate graph " << hlo_module->name();
VLOG(1) << "Run backend " << hlo_module->name();
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));

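
Taken together with the LLVMCompiler and Service changes below, callers now
split the old single Compile() entry point into two explicit phases. A hedged
sketch of a typical call site (compiler, module, and executor are assumed to
be in scope inside a function returning Status or StatusOr):

  // Phase 1: run the HLO optimization pipeline over the module.
  TF_ASSIGN_OR_RETURN(module,
                      compiler->RunHloPasses(std::move(module), executor));
  // Phase 2: lower the already-optimized module to an Executable.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      compiler->RunBackend(std::move(module), executor));
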
View File

@ -43,8 +43,12 @@ class InterpreterCompiler : public Compiler {
InterpreterCompiler() {}
~InterpreterCompiler() override {}
StatusOr<std::unique_ptr<Executable>> Compile(
std::unique_ptr<HloModule> hlo_modules,
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> hlo_module,
perftools::gputools::StreamExecutor* stream_exec) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> hlo_module,
perftools::gputools::StreamExecutor* stream_exec) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(

View File

@ -103,7 +103,7 @@ namespace {
// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
// where 'user' is a user of an alias of 'intruction' at 'index', and
// where 'user' is a user of an alias of 'instruction' at 'index', and
// 'operand_index' is the operand index at which the alias appears in the
// operand list of 'user'.
std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
@ -243,6 +243,31 @@ bool CanShareOperandBufferWithUser(
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
if (user->opcode() == HloOpcode::kCall) {
// TODO(b/62548313): Remove when buffer assignment is module scoped and
// does not assign buffers to calls.
// Find called computation parameter associated with 'operand'.
const std::vector<int64> operand_indices = user->OperandIndices(operand);
if (operand_indices.size() > 1) {
return false;
}
CHECK_EQ(1, operand_indices.size());
auto* param = user->to_apply()->parameter_instruction(operand_indices[0]);
// Get all uses of 'operand' at 'index' in called computation.
auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index,
points_to_analysis);
// Return true iff:
// *) There exists exactly one use of 'operand' in called computation.
// *) The unique use is by the root instruction of called computation.
// (Note: we check the root of the called computation, because the
// root result buffer is required to alias with the Call result buffer).
// *) The root instruction of the called computation is element-wise on
// 'operand'.
auto* callee_root = user->to_apply()->root_instruction();
return param_uses.size() == 1 && param_uses[0].first == callee_root &&
callee_root->IsElementwiseOnOperand(param_uses[0].second);
}
// Check if 'user' is element-wise.
return user->IsElementwise();
}
@ -322,6 +347,31 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand,
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
if (user->opcode() == HloOpcode::kCall) {
// Get all uses of value defined by 'operand' at 'operand_index'.
const auto& uses =
dataflow.GetValueDefinedAt(operand, operand_index).uses();
// Return true iff:
// *) There exist exactly two uses of 'operand'.
// *) One use is by 'user' (caller).
// *) One use is by root instruction of called computation (callee root).
// (Note: we check the root of the called computation, because the
// root result buffer is required to alias with the Call result buffer).
// *) The root instruction of the called computation is element-wise on
// 'operand'.
const bool found_caller_use =
std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) {
return use.instruction == user;
}) != uses.end();
auto* callee_root = user->to_apply()->root_instruction();
const bool found_elementwise_callee_use =
std::find_if(
uses.begin(), uses.end(), [callee_root](const HloUse& use) {
return use.instruction == callee_root &&
callee_root->IsElementwiseOnOperand(use.operand_number);
}) != uses.end();
return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
}
// Check if 'user' is element-wise.
return user->IsElementwise();
}

View File

@ -415,5 +415,44 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_));
}
// Tests that Call can alias operand buffer if the only use of the operand
// in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
Shape shape = ShapeUtil::MakeShape(F32, {8});
// Build sub-computation with fusion root.
auto sub_builder = HloComputation::Builder(TestName() + "_sub");
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "sub_param"));
auto one = sub_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto ones = sub_builder.AddInstruction(
HloInstruction::CreateBroadcast(shape, one, {1}));
auto add = sub_builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
module_ = CreateNewModule();
auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
sub_computation->CreateFusionInstruction({add, ones},
HloInstruction::FusionKind::kLoop);
// Build entry-computation with kCall which calls 'sub_computation'.
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
auto reverse =
builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
auto call = builder.AddInstruction(
HloInstruction::CreateCall(shape, {reverse}, sub_computation));
computation_ = module_->AddEntryComputation(builder.Build());
RunAnalysis();
EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {},
*points_to_analysis_));
EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {},
*dataflow_analysis_));
}
} // namespace
} // namespace xla

View File

@ -27,8 +27,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
"Model partitioning not implemented for the CPU/GPU compilers!");
}
TF_ASSIGN_OR_RETURN(
modules[i], RunHloPasses(std::move(modules[i]), stream_execs[i][0]));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
Compile(std::move(modules[i]), stream_execs[i][0]));
RunBackend(std::move(modules[i]), stream_execs[i][0]));
result.push_back(std::move(executable));
}

View File

@ -58,10 +58,14 @@ class LLVMCompiler : public Compiler {
void RemovePostOptimizationHook() { user_post_optimization_hook_ = nullptr; }
// Bring in
// StatusOr<std::unique_ptr<Executable>> Compile(
// std::unique_ptr<HloModule> module,
// perftools::gputools::StreamExecutor* executor)
using Compiler::Compile;
// StatusOr<std::unique_ptr<Executable>> RunBackend(
// std::unique_ptr<HloModule> module,
// perftools::gputools::StreamExecutor* stream_exec)
// StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
// std::unique_ptr<HloModule> module,
// perftools::gputools::StreamExecutor* stream_exec)
using Compiler::RunBackend;
using Compiler::RunHloPasses;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::vector<std::unique_ptr<HloModule>> modules,

View File

@ -40,11 +40,24 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
inline bool CanEmitFusedDynamicUpdateSliceInPlace(
HloInstruction* fusion, const BufferAssignment& assignment) {
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop &&
fusion->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice &&
CanUpdateDynamicSliceInPlace(fusion->fused_expression_root(),
assignment);
HloInstruction* fused_root = fusion->fused_expression_root();
if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice ||
fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) {
return false;
}
// Walk DynamicUpdateSlice operand(0) to fused parameter and get its
// associated operand. See if it shares an allocation with this operand.
HloInstruction* fusion_operand;
ShapeIndex index;
std::tie(fusion_operand, index) =
fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
if (fusion_operand->opcode() != HloOpcode::kParameter) {
return false;
}
auto* operand = fusion->operand(fusion_operand->parameter_number());
return assignment.HasAllocationAt(operand, index) &&
assignment.HasAllocationAt(fusion, {}) &&
assignment.SharesSliceAtIndex(fusion, {}, operand, index);
}
// Emits IR for running the given dynamic-update-slice op in-place -- that is,

View File

@ -430,9 +430,12 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
/*include_unreachable_instructions=*/
true));
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend->compiler()->Compile(std::move(module), executor));
backend->compiler()->RunBackend(std::move(module), executor));
if (!other_directory_path.empty()) {
executable->set_session_module(std::move(session_module));
@ -1361,6 +1364,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status =
computation->AddConvertInstruction(arg->convert_request());
break;
case OpRequest::kBitcastConvertRequest:
handle_status = computation->AddBitcastConvertInstruction(
arg->bitcast_convert_request());
break;
case OpRequest::kConvolveRequest:
handle_status =
computation->AddConvolveInstruction(arg->convolve_request());

View File

@ -441,6 +441,14 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type) {
auto old_element_type = operand_shape.element_type();
if (primitive_util::IsComplexType(old_element_type) &&
!primitive_util::IsComplexType(new_element_type)) {
return Unimplemented(
"Unsupported conversion from complex to real type: %s => %s",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
@ -454,6 +462,36 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
}
/* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type) {
auto old_element_type = operand_shape.element_type();
if (primitive_util::IsComplexType(old_element_type) !=
primitive_util::IsComplexType(new_element_type)) {
return Unimplemented(
"Unsupported conversion between real and complex types: %s => %s",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
return InvalidArgument(
"cannot convert from or to tuple type; requested conversion: %s => %s",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
if (primitive_util::BitWidth(old_element_type) !=
primitive_util::BitWidth(new_element_type)) {
return InvalidArgument(
"cannot bitcast types with different bit-widths: %s => %s",
PrimitiveType_Name(old_element_type).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
}
/* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
const Shape& operand_shape, const int exponent_bits,
const int mantissa_bits) {

View File

@ -204,6 +204,13 @@ class ShapeInference {
static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
PrimitiveType new_element_type);
// Helper that validates the given operand shape can be bitcast converted to
// the target element type via a bitcast convert instruction -- the requirement
// is that the shape stays identical except for the element type, and that the
// old and new element types have identical bit-widths.
static StatusOr<Shape> InferBitcastConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type);
// Helper that validates the input data type for a reduce-precision operation,
// and returns the result shape.
static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,

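
A small illustration of the bit-width rule above (a sketch only; the shapes
and calls below are made up and not part of the diff):

  // F32 and S32 are both 32-bit, so the bitcast-convert is accepted and the
  // inferred result shape is s32[2,3]; F32 -> F64 changes the bit-width and
  // is rejected with InvalidArgument.
  Shape operand = ShapeUtil::MakeShape(F32, {2, 3});
  StatusOr<Shape> ok = ShapeInference::InferBitcastConvertShape(operand, S32);
  StatusOr<Shape> bad = ShapeInference::InferBitcastConvertShape(operand, F64);
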
View File

@ -994,6 +994,32 @@ StatusOr<ComputationDataHandle> UserComputation::AddConvertInstruction(
return handle;
}
StatusOr<ComputationDataHandle> UserComputation::AddBitcastConvertInstruction(
const ConvertRequest& convert_request) {
tensorflow::mutex_lock lock(mutex_);
TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
LookUpRequest(convert_request.operand()));
TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape(
operand->output_shape(),
convert_request.new_element_type()));
ComputationDataHandle handle = CreateComputationDataHandle();
OperationRequest& request =
(*session_computation_.mutable_requests())[handle.handle()];
*request.mutable_output_handle() = handle;
*request.mutable_output_shape() = new_shape;
*request.mutable_request()->mutable_bitcast_convert_request() =
convert_request;
VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal()
<< "), data handle " << handle.handle() << ": "
<< convert_request.ShortDebugString();
return handle;
}
StatusOr<ComputationDataHandle> UserComputation::AddReducePrecisionInstruction(
const ReducePrecisionRequest& reduce_precision_request) {
tensorflow::mutex_lock lock(mutex_);
@ -2370,6 +2396,13 @@ static void ForEachOperand(
break;
}
case OpRequest::kBitcastConvertRequest: {
const ConvertRequest& convert_request =
request.request().bitcast_convert_request();
apply(convert_request.operand());
break;
}
case OpRequest::kWhileRequest: {
const WhileRequest& while_request = request.request().while_request();
apply(while_request.init());
@ -2954,6 +2987,15 @@ void ComputationLowerer::Visit(
break;
}
case OpRequest::kBitcastConvertRequest: {
const ConvertRequest& convert_request =
request.request().bitcast_convert_request();
HloInstruction* operand = lookup_instruction(convert_request.operand());
hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert(
request.output_shape(), operand));
break;
}
case OpRequest::kWhileRequest: {
const WhileRequest& while_request = request.request().while_request();
CHECK_EQ(2, request.embedded_computation_versions_size());
@ -2978,6 +3020,25 @@ void ComputationLowerer::Visit(
HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs());
HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs());
auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop());
if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
// lhs side is being implicitly broadcast. Change to explicit.
lhs =
ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
}
if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
rhs =
ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
}
if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) {
ehs =
ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape());
}
}
hlo_instruction = add_instruction(HloInstruction::CreateTernary(
request.output_shape(), hlo_opcode, lhs, rhs, ehs));
break;
@ -3137,7 +3198,7 @@ void ComputationLowerer::Visit(
LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
}
(*instructions)[handle.handle()] = hlo_instruction;
}
} // NOLINT(readability/fn_size)
} // namespace

View File

@ -70,7 +70,7 @@ class UserComputation {
// Enqueues a pad instruction onto this user computation.
StatusOr<ComputationDataHandle> AddPadInstruction(
const PadRequest& parameter_request);
const PadRequest& pad_request);
// Enqueues a tracing instruction onto this user computation.
// Returns an error status if the operand cannot be resolved.
@ -105,7 +105,7 @@ class UserComputation {
// Enqueues a ternary instruction onto this user computation.
// Returns an error status if the operand indices are out of bounds.
StatusOr<ComputationDataHandle> AddTernaryInstruction(
const TernaryOpRequest& request);
const TernaryOpRequest& ternary_request);
// Enqueues a variadic instruction onto this user computation.
// Returns an error status if the operand indices are out of bounds.
@ -179,26 +179,30 @@ class UserComputation {
// Enqueues a concatenate instruction onto this user computation.
StatusOr<ComputationDataHandle> AddConcatenateInstruction(
const ConcatenateRequest& slice_request);
const ConcatenateRequest& concatenate_request);
// Enqueues a convert instruction onto this user computation.
StatusOr<ComputationDataHandle> AddConvertInstruction(
const ConvertRequest& convert_request);
// Enqueues a bitcast element instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBitcastConvertInstruction(
const ConvertRequest& convert_request);
// Enqueues a reduce instruction onto this user computation.
StatusOr<ComputationDataHandle> AddReduceInstruction(
const ReduceRequest& reduce_request,
const UserComputation& reduction_computation);
const UserComputation& to_apply_computation);
// Enqueues a windowed reduce instruction onto this user computation.
StatusOr<ComputationDataHandle> AddReduceWindowInstruction(
const ReduceWindowRequest& reduce_window_request,
const UserComputation& reduction_computation);
const UserComputation& to_apply_computation);
// Enqueues a select-and-scatter instruction onto this user
// computation.
StatusOr<ComputationDataHandle> AddSelectAndScatterInstruction(
const SelectAndScatterRequest& scatter_to_selected_window_element_request,
const SelectAndScatterRequest& select_and_scatter_request,
const UserComputation& select_computation,
const UserComputation& scatter_computation);

View File

@ -403,6 +403,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// Compute the shape of the while op after we remove the dead indices.
std::vector<Shape> new_while_tuple_elem_shapes;
new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
for (int64 old_idx : new_to_old_tuple_idx) {
new_while_tuple_elem_shapes.push_back(
while_init->shape().tuple_shapes(old_idx));
@ -469,6 +470,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
while_body_replacements = make_while_computation_replacements(while_body);
std::vector<HloInstruction*> new_while_body_root_elems;
new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
for (int64 old_idx : new_to_old_tuple_idx) {
new_while_body_root_elems.push_back(
while_body_root->mutable_operand(old_idx));
@ -483,6 +485,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// clean this up in the common case where while_init is a tuple op. (It's
// definitely tuple-shaped, but it's not necessarily a tuple op.)
std::vector<HloInstruction*> new_while_init_elems;
new_while_init_elems.reserve(new_to_old_tuple_idx.size());
for (int64 old_idx : new_to_old_tuple_idx) {
new_while_init_elems.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(

View File

@ -382,7 +382,6 @@ xla_test(
name = "params_test",
srcs = ["params_test.cc"],
shard_count = 30,
tags = ["optonly"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",
@ -512,6 +511,7 @@ xla_test(
name = "array_elementwise_ops_test",
srcs = ["array_elementwise_ops_test.cc"],
shard_count = 25,
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
@ -770,6 +770,38 @@ xla_test(
],
)
xla_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.cc"],
shard_count = 40,
deps = [
":test_utils",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
xla_test(
name = "slice_test",
srcs = ["slice_test.cc"],
@ -1230,6 +1262,23 @@ xla_test(
],
)
xla_test(
name = "bitcast_convert_test",
srcs = ["bitcast_convert_test.cc"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
],
)
xla_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.cc"],

View File

@ -0,0 +1,75 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
class Bfloat16Test : public ClientLibraryTestBase {
protected:
const ErrorSpec error_spec_{0.001, 0.001};
};
XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
DISABLED_ON_CPU(ScalarOperation)))) {
ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
builder.Add(x, y);
ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {},
error_spec_);
}
XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
DISABLED_ON_CPU(NegateScalarF16)))) {
ComputationBuilder builder(client_, TestName());
builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));
ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {},
error_spec_);
}
} // namespace
} // namespace xla

View File

@ -0,0 +1,141 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
class BitcastConvertTest : public ClientLibraryTestBase {
public:
explicit BitcastConvertTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform) {
mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
mutable_debug_options()->add_xla_disable_hlo_passes("inline");
}
};
TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<int32>({42, 64});
builder.BitcastConvertType(a, S32);
std::vector<int32> expected = {42, 64};
ComputeAndCompareR1<int32>(&builder, expected, {});
}
TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({42.0f, 64.0f});
builder.BitcastConvertType(a, F32);
std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {});
}
TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) {
ComputationBuilder builder(client_, TestName());
auto a =
builder.ConstantR1<int32>({0, static_cast<int32>(0x80000000), 0x3F800000,
static_cast<int32>(0xBF800000), 0x3F000000,
static_cast<int32>(0xBF000000)});
builder.BitcastConvertType(a, F32);
std::vector<float> expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f};
ComputeAndCompareR1<float>(&builder, expected, {});
}
XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<int32>({});
builder.BitcastConvertType(a, F32);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {});
}
TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({42.6, 64.4});
builder.BitcastConvertType(a, S32);
std::vector<int32> expected = {0x422a6666, 0x4280cccd};
ComputeAndCompareR1<int32>(&builder, expected, {});
}
TEST_F(BitcastConvertTest, ConvertS32Extremes) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<int32>(
{std::numeric_limits<int32>::min(), std::numeric_limits<int32>::max()});
builder.BitcastConvertType(a, F32);
std::vector<float> expected = {-0.0f, NAN};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0, 0));
}
TEST_F(BitcastConvertTest, ConvertMapToS32) {
ComputationBuilder builder(client_, TestName());
auto b = builder.CreateSubBuilder("convert");
auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in");
b->BitcastConvertType(param, S32);
auto a = builder.ConstantR1<float>({42.0f, 64.0f});
builder.Map({a}, b->BuildAndNoteError(), {0});
std::vector<int32> expected = {0x42280000, 0x42800000};
ComputeAndCompareR1<int32>(&builder, expected, {});
}
TEST_F(BitcastConvertTest, ConvertMapToF32) {
ComputationBuilder builder(client_, TestName());
auto b = builder.CreateSubBuilder("convert");
auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in");
b->BitcastConvertType(param, F32);
auto a = builder.ConstantR1<int32>({0x42280000, 0x42800000});
builder.Map({a}, b->BuildAndNoteError(), {0});
std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {});
}
// Regression test for b/31758660. When ReshapeMover transforms
// input -> reshape -> convert
// to
// input -> convert -> reshape
// the new convert should have the same element type as the old convert.
TEST_F(BitcastConvertTest, ConvertReshape) {
ComputationBuilder builder(client_, TestName());
auto input = builder.ConstantR1<int32>({0x42280000});
auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
builder.BitcastConvertType(reshape, F32);
ComputeAndCompareR0<float>(&builder, 42.0f, {});
}
} // namespace
} // namespace xla

View File

@ -333,6 +333,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =

View File

@ -19,8 +19,11 @@ namespace xla {
StatusOr<std::unique_ptr<Executable>> CodegenTestBase::CompileToExecutable(
std::unique_ptr<HloModule> hlo_module) {
return backend().compiler()->Compile(std::move(hlo_module),
backend().default_stream_executor());
TF_ASSIGN_OR_RETURN(hlo_module, backend().compiler()->RunHloPasses(
std::move(hlo_module),
backend().default_stream_executor()));
return backend().compiler()->RunBackend(std::move(hlo_module),
backend().default_stream_executor());
}
StatusOr<std::unique_ptr<AotCompilationResult>>

View File

@ -458,6 +458,54 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) {
error_spec_);
}
// Test fixture to run convolution tests with and without convolution
// canonicalization enabled.
class ConvolveWithAndWithoutCanonicalization
: public ConvolutionTest,
public ::testing::WithParamInterface<bool> {};
XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
DISABLED_ON_GPU(Convolve2D_NoSpatialDims)) {
if (GetParam()) {
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"convolution-canonicalization");
}
ComputationBuilder builder(client_, TestName());
Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29});
Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
ConvolutionDimensionNumbers dnums;
dnums.set_input_feature_dimension(0);
dnums.set_input_batch_dimension(1);
dnums.set_kernel_input_feature_dimension(0);
dnums.set_kernel_output_feature_dimension(1);
dnums.set_output_batch_dimension(0);
dnums.set_output_feature_dimension(1);
auto conv = builder.ConvWithGeneralDimensions(input, filter, {},
Padding::kValid, dnums);
Array2D<float> param0(4, 29);
param0.FillUnique();
Array2D<float> param1(4, 10);
param1.FillUnique();
Array2D<float> expected_result(29, 10);
expected_result.Fill(0);
ComputeAndCompare(
&builder, conv,
{*Literal::CreateFromArray(param0), *Literal::CreateFromArray(param1)},
error_spec_);
}
INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation,
ConvolveWithAndWithoutCanonicalization,
::testing::Values(true, false));
struct Convolve1DTestParam {
int64 input_feature;
int64 output_feature;

View File

@ -340,6 +340,9 @@ class NearComparator {
multi_index_.resize(expected.shape().dimensions_size(), 0);
switch (expected.shape().element_type()) {
case BF16:
ExpectLiteralsNear<bfloat16>(expected, actual, 0);
break;
case F32:
ExpectLiteralsNear<float>(expected, actual, 0);
break;
@ -525,6 +528,13 @@ void NearComparator::ExpectNear<complex64>(complex64 expected, complex64 actual,
<< message;
}
template <>
bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
bfloat16 actual) {
return ExpectValuesNear(static_cast<float>(expected),
static_cast<float>(actual));
}
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(

View File

@ -73,8 +73,8 @@ class LLVMCompilerTest : public ::testing::Test {
compiler->SetPostOptimizationHook(post_opt_hook);
ASSERT_TRUE(compiler
->Compile(std::move(hlo_module),
backend_->default_stream_executor())
->RunBackend(std::move(hlo_module),
backend_->default_stream_executor())
.ok());
// Test that hooks were called.

Some files were not shown because too many files have changed in this diff.