Merge pull request #14804 from yifeif/branch_176676125
Branch 176676125
commit 79422ab39b
@@ -645,6 +645,7 @@ filegroup(
"//tensorflow/tools/test:all_files",
"//tensorflow/user_ops:all_files",
"//third_party/hadoop:all_files",
"//third_party/mpi:all_files",
"//third_party/sycl:all_files",
"//third_party/sycl/sycl:all_files",
],
@@ -939,13 +939,17 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
return;
}

std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
dim_vec.reserve(num_dims);
for (int i = 0; i < num_dims; ++i) {
dim_vec.push_back(ic->MakeDim(dims[i]));
tensorflow::shape_inference::ShapeHandle new_shape;
if (num_dims != -1) {
std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
dim_vec.reserve(num_dims);
for (int i = 0; i < num_dims; ++i) {
dim_vec.push_back(ic->MakeDim(dims[i]));
}
new_shape = ic->MakeShape(dim_vec);
} else {
new_shape = ic->UnknownShape();
}

tensorflow::shape_inference::ShapeHandle new_shape = ic->MakeShape(dim_vec);
status->status = graph->refiner.SetShape(node, output.index, new_shape);
}
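The hunk above changes TF_GraphSetTensorShape so that num_dims == -1 now means "unknown shape" and dims is only read when a rank is supplied. A minimal sketch of the two calling conventions, assuming graph, output, and s are a valid TF_Graph*, TF_Output, and TF_Status* as in the tests below:

// Known rank and dimensions: 2 x 3.
int64_t dims[] = {2, 3};
TF_GraphSetTensorShape(graph, output, dims, /*num_dims=*/2, s);
// Unknown rank: dims may be null; the shape refiner records an unknown shape.
TF_GraphSetTensorShape(graph, output, /*dims=*/nullptr, /*num_dims=*/-1, s);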
@@ -287,6 +287,13 @@ TEST(CAPI, SetShape) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(-1, num_dims);

// Set the shape to be unknown, expect no change.
TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(-1, num_dims);

// Set the shape to be 2 x Unknown
int64_t dims[] = {2, -1};
TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);

@@ -315,7 +322,17 @@ TEST(CAPI, SetShape) {
EXPECT_EQ(dims[0], returned_dims[0]);
EXPECT_EQ(dims[1], returned_dims[1]);

// Try to set 'unknown' on the shape and see that
// Try to set 'unknown' with unknown rank on the shape and see that
// it doesn't change.
TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, returned_dims[0]);
EXPECT_EQ(3, returned_dims[1]);

// Try to set 'unknown' with same rank on the shape and see that
// it doesn't change.
dims[0] = -1;
dims[1] = -1;
@@ -571,6 +571,12 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
status->status = ctx->func_lib_def.AddFunctionDef(function_def);
}

void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
tensorflow::mutex_lock l(ctx->functions_mu);
status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
}

} // extern "C"

TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
@@ -200,6 +200,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
const char* serialized_function_def,
size_t size, TF_Status* status);

// Adds a function (created from TF_GraphToFunction or
// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with
// TFE_Execute by creating an op with the same name as the function.
TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx,
TF_Function* function,
TF_Status* status);

#ifdef __cplusplus
} /* end extern "C" */
#endif
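For orientation, a condensed sketch of the call sequence the new API enables, with error checking omitted and fn assumed to come from TF_GraphToFunction as the comment above describes (the full, checked version is the TEST(CAPI, Function) added below):

TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
TFE_ContextAddFunction(ctx, fn, status);
// The function is now callable as an op with its own name.
TFE_Op* op = TFE_NewOp(ctx, "ident", status);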
@@ -295,6 +295,66 @@ TEST(CAPI, Execute) {
TF_DeleteStatus(status);
}

TEST(CAPI, Function) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();
TF_OperationDescription* arg_descr =
TF_NewOperation(function_graph, "Placeholder", "arg");
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
TF_Status* status = TF_NewStatus();
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_OperationDescription* id_descr =
TF_NewOperation(function_graph, "Identity", "id");
TF_SetAttrType(id_descr, "T", TF_INT32);
TF_AddInput(id_descr, {arg, 0});
TF_Operation* id = TF_FinishOperation(id_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_Output input{arg, 0};
TF_Output output{id, 0};
TF_Function* fn =
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
&output, nullptr, nullptr, "test", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteGraph(function_graph);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextAddFunction(ctx, fn, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteFunction(fn);

TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, 1);
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteTensor(t);

TFE_Op* op = TFE_NewOp(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_OpAddInput(op, h, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);

std::vector<TFE_TensorHandle*> result;
result.push_back(nullptr);
int num_retvals = 1;
TFE_Execute(op, result.data(), &num_retvals, status);
TFE_DeleteOp(op);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
ASSERT_EQ(num_retvals, 1);

TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
TFE_DeleteTensorHandle(h);
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
TFE_DeleteContext(ctx, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}

string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
@@ -485,7 +485,6 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
Status s = vspace.CallBackwardFunction(trace.backward_function,
out_gradients, &in_gradients);
if (!s.ok()) {
VLOG(1) << "Gradient function failed.";
cleanup();
return s;
}
@@ -421,6 +421,7 @@ tf_cc_test(

tf_gen_op_wrappers_cc(
name = "cc_ops",
api_def_srcs = ["//tensorflow/core:base_api_def"],
op_lib_names = [
"array_ops",
"audio_ops",
@@ -525,6 +526,9 @@ cc_library_with_android_deps(
"//tensorflow/core:android_tensorflow_lib",
],
copts = tf_copts(),
data = [
"//tensorflow/core:base_api_def",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -536,6 +540,29 @@ cc_library_with_android_deps(
],
)

tf_cc_test(
name = "cc_op_gen_test",
srcs = [
"framework/cc_op_gen.cc",
"framework/cc_op_gen.h",
"framework/cc_op_gen_test.cc",
],
data = [
"//tensorflow/cc:ops/op_gen_overrides.pbtxt",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:op_gen_overrides_proto_cc",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

cc_library(
name = "test_op_op_lib",
srcs = ["framework/test_op.cc"],
@@ -18,10 +18,11 @@ limitations under the License.
#include <vector>

#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb_text.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/public/version.h"

namespace tensorflow {

namespace {

const int kRightMargin = 79;
@@ -297,7 +297,7 @@ string ToCamelCase(const string& str) {
// argument to a function.
std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
static const std::unordered_map<StringPiece, std::pair<const char*, bool>,
StringPiece::Hasher>
StringPieceHasher>
attr_type_map{
{"string", {"StringPiece", false}},
{"list(string)", {"gtl::ArraySlice<string>", true}},
@@ -325,29 +325,112 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
}

bool IsCPPKeyword(StringPiece name) {
static const std::unordered_set<StringPiece, StringPiece::Hasher>
static const std::unordered_set<StringPiece, StringPieceHasher>
// Keywords obtained from http://en.cppreference.com/w/cpp/keyword
kCPPReserved{
"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel",
"atomic_commit", "atomic_noexcept", "auto", "bitand", "bitor", "bool",
"break", "case", "catch", "char", "char16_t", "char32_t", "class",
"compl", "concept", "const", "const_cast", "constexpr", "continue",
"decltype", "default", "delete", "do", "double", "dynamic_cast",
"else", "enum", "explicit", "export", "extern", "false", "final",
"float", "for", "friend", "goto", "if", "import", "inline", "int",
"long", "module", "mutable", "namespace", "new", "noexcept", "not",
"not_eq", "nullptr", "operator", "or", "or_eq", "override", "private",
"protected", "public", "register", "reinterpret_cast", "requires",
"return", "short", "signed", "sizeof", "static", "static_assert",
"static_cast", "struct", "switch", "synchronized", "template", "this",
"thread_local", "throw", "true", "try", "typedef", "typeid",
"typename", "union", "unsigned", "using", "virtual", "void",
"volatile", "wchar_t", "while", "xor", "xor_eq",
"alignas",
"alignof",
"and",
"and_eq",
"asm",
"atomic_cancel",
"atomic_commit",
"atomic_noexcept",
"auto",
"bitand",
"bitor",
"bool",
"break",
"case",
"catch",
"char",
"char16_t",
"char32_t",
"class",
"compl",
"concept",
"const",
"const_cast",
"constexpr",
"continue",
"decltype",
"default",
"delete",
"do",
"double",
"dynamic_cast",
"else",
"enum",
"explicit",
"export",
"extern",
"false",
"final",
"float",
"for",
"friend",
"goto",
"if",
"import",
"inline",
"int",
"long",
"module",
"mutable",
"namespace",
"new",
"noexcept",
"not",
"not_eq",
"nullptr",
"operator",
"or",
"or_eq",
"override",
"private",
"protected",
"public",
"register",
"reinterpret_cast",
"requires",
"return",
"short",
"signed",
"sizeof",
"static",
"static_assert",
"static_cast",
"struct",
"switch",
"synchronized",
"template",
"this",
"thread_local",
"throw",
"true",
"try",
"typedef",
"typeid",
"typename",
"union",
"unsigned",
"using",
"virtual",
"void",
"volatile",
"wchar_t",
"while",
"xor",
"xor_eq",

// The following are not C++ keywords, but names of local variables
// and parameters used in the op constructor. Treating them as
// keywords, so that other parameter names don't conflict with these.
"builder", "node", "ret", "scope", "unique_name",
"builder",
"node",
"ret",
"scope",
"unique_name",
};
return kCPPReserved.count(name) > 0;
}
@@ -385,10 +468,10 @@ bool ArgIsList(const OpDef::ArgDef& arg) {
}

bool HasOptionalAttrs(
const OpDef& op_def,
const ApiDef& api_def,
const std::unordered_map<string, string>& inferred_input_attrs) {
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < api_def.attr_size(); ++i) {
const auto& attr(api_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) ==
inferred_input_attrs.end()) &&
attr.has_default_value()) {
@@ -398,12 +481,21 @@ bool HasOptionalAttrs(
return false;
}

const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
for (int i = 0; i < api_def.in_arg_size(); ++i) {
if (api_def.in_arg(i).name() == name) {
return &api_def.in_arg(i);
}
}
return nullptr;
}

struct OpInfo {
// graph_op_def: The OpDef used by the runtime, has the names that
// must be used when calling NodeBuilder.
// interface_op_def: The OpDef used in the interface in the generated
// code, with possibly overridden names and defaults.
explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def,
explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases);
string GetOpAttrStruct() const;
string GetConstructorDecl(StringPiece op_name_prefix,
@@ -423,74 +515,81 @@ struct OpInfo {
string comment;

const OpDef& graph_op_def;
const OpDef& op_def;
const ApiDef& api_def;
const std::vector<string>& aliases;
// Map from type attribute to corresponding original argument name.
std::unordered_map<string, string> inferred_input_attrs;
};

OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
const std::vector<string>& a)
: graph_op_def(g_op_def), op_def(i_op_def), aliases(a) {
op_name = op_def.name();
InferOpAttributes(op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs);
OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases)
: graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
op_name = api_def.endpoint(0).name();
InferOpAttributes(graph_op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
arg_types.push_back("const ::tensorflow::Scope&");
arg_names.push_back("scope");

if (op_def.has_deprecation()) {
if (!op_def.summary().empty()) {
comment = strings::StrCat(op_def.summary(), "\n");
if (graph_op_def.has_deprecation()) {
if (!api_def.summary().empty()) {
comment = strings::StrCat(api_def.summary(), "\n");
}
strings::StrAppend(&comment, "DEPRECATED at GraphDef version ",
op_def.deprecation().version(), ":\n",
op_def.deprecation().explanation(), ".\n");
} else if (op_def.summary().empty()) {
graph_op_def.deprecation().version(), ":\n",
graph_op_def.deprecation().explanation(), ".\n");
} else if (api_def.summary().empty()) {
comment = "TODO: add doc.\n";
} else {
comment = strings::StrCat(op_def.summary(), "\n");
comment = strings::StrCat(api_def.summary(), "\n");
}
if (!op_def.description().empty()) {
strings::StrAppend(&comment, "\n", op_def.description(), "\n");
if (!api_def.description().empty()) {
strings::StrAppend(&comment, "\n", api_def.description(), "\n");
}
strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");

// Process inputs
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
for (int i = 0; i < api_def.arg_order_size(); ++i) {
const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def);
const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def);
arg_types.push_back(strings::StrCat(
"::tensorflow::", ArgIsList(arg) ? "InputList" : "Input"));
arg_names.push_back(AvoidCPPKeywords(arg.name()));
arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));

// TODO(keveman): Include input type information.
StringPiece description = arg.description();
StringPiece description = api_def_arg.description();
if (!description.empty()) {
ConsumeEquals(&description);
strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ",
arg.description(), "\n");
strings::StrAppend(&comment, "* ",
AvoidCPPKeywords(api_def_arg.rename_to()), ": ",
api_def_arg.description(), "\n");
}
}

// Process attrs
string required_attrs_comment;
string optional_attrs_comment;
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
// ApiDef attributes must be in the same order as in OpDef since
// we initialize ApiDef based on OpDef.
const auto& attr(graph_op_def.attr(i));
const auto& api_def_attr(api_def.attr(i));
CHECK_EQ(attr.name(), api_def_attr.name());
// Skip inferred arguments
if (inferred_input_attrs.count(attr.name()) > 0) continue;

const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
string attr_name = AvoidCPPKeywords(attr.name());
string attr_name = AvoidCPPKeywords(api_def_attr.rename_to());

string attr_comment;
if (!attr.description().empty()) {
if (!api_def_attr.description().empty()) {
// TODO(keveman): Word wrap and indent this, to handle multi-line
// descriptions.
strings::StrAppend(&attr_comment, "* ", attr_name, ": ",
attr.description(), "\n");
api_def_attr.description(), "\n");
}
if (attr.has_default_value()) {
if (api_def_attr.has_default_value()) {
strings::StrAppend(&optional_attrs_comment, attr_comment);
} else {
strings::StrAppend(&required_attrs_comment, attr_comment);
@@ -508,44 +607,49 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
}

// Process outputs
for (int i = 0; i < op_def.output_arg_size(); ++i) {
const auto& arg = op_def.output_arg(i);
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
// ApiDef arguments must be in the same order as in OpDef since
// we initialize ApiDef based on OpDef.
const auto& arg = graph_op_def.output_arg(i);
const auto& api_def_arg(api_def.out_arg(i));
CHECK_EQ(arg.name(), api_def_arg.name());

bool is_list = ArgIsList(arg);
output_types.push_back(
strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output"));
output_names.push_back(AvoidCPPKeywords(arg.name()));
output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
is_list_output.push_back(is_list);
}

strings::StrAppend(&comment, "\nReturns:\n");
if (op_def.output_arg_size() == 0) {  // No outputs.
if (graph_op_def.output_arg_size() == 0) {  // No outputs.
strings::StrAppend(&comment, "* the created `Operation`\n");
} else if (op_def.output_arg_size() == 1) {  // One output
} else if (graph_op_def.output_arg_size() == 1) {  // One output
if (is_list_output[0]) {
strings::StrAppend(&comment, "* `OutputList`: ");
} else {
strings::StrAppend(&comment, "* `Output`: ");
}
if (op_def.output_arg(0).description().empty()) {
strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(),
if (api_def.out_arg(0).description().empty()) {
strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(),
" tensor.\n");
} else {
// TODO(josh11b): Word wrap this.
strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n");
strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n");
}
} else {  // Multiple outputs.
for (int i = 0; i < op_def.output_arg_size(); ++i) {
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
if (is_list_output[i]) {
strings::StrAppend(&comment, "* `OutputList`");
} else {
strings::StrAppend(&comment, "* `Output`");
}
strings::StrAppend(&comment, " ", output_names[i]);
if (op_def.output_arg(i).description().empty()) {
if (api_def.out_arg(i).description().empty()) {
strings::StrAppend(&comment, "\n");
} else {
// TODO(josh11b): Word wrap this.
strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(),
strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(),
"\n");
}
}
@@ -564,19 +668,20 @@ string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;

for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
const auto& api_def_attr(api_def.attr(i));
// If attr will be inferred or it doesn't have a default value, don't
// add it to the struct.
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
!attr.has_default_value()) {
!api_def_attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
const string camel_case_name = ToCamelCase(attr.name());
const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def =
@@ -584,22 +689,25 @@ string OpInfo::GetOpAttrStruct() const {
attr_type_name, use_const ? "&" : "");

string attr_comment;
if (!attr.description().empty()) {
strings::StrAppend(&attr_comment, attr.description(), "\n\n");
if (!api_def_attr.description().empty()) {
strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n");
}
strings::StrAppend(&attr_comment, "Defaults to ",
SummarizeAttrValue(attr.default_value()), "\n");
SummarizeAttrValue(api_def_attr.default_value()), "\n");
attr_comment = MakeComment(attr_comment, " ");

strings::StrAppend(&setters, attr_comment);
strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n");
strings::StrAppend(&setters, " Attrs ret = *this;\n");
strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n");
strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(),
"_ = x;\n");
strings::StrAppend(&setters, " return ret;\n }\n\n");

strings::StrAppend(
&struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ",
PrintAttrValue(op_def.name(), attr.default_value()), ";\n");
&struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(),
"_ = ",
PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
";\n");
}

if (struct_fields.empty()) {
@@ -676,17 +784,18 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
// Add the static functions to set optional attrs
if (has_optional_attrs) {
strings::StrAppend(&class_decl, "\n");
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
const auto& api_def_attr(api_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
!attr.has_default_value()) {
!api_def_attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
const string camel_case_name = ToCamelCase(attr.name());
const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def = strings::StrCat(
@@ -726,11 +835,11 @@ void OpInfo::GetOutput(string* out) const {
strings::StrCat("if (!", scope_str, ".ok()) return;");

// No outputs.
if (op_def.output_arg_size() == 0) {
if (graph_op_def.output_arg_size() == 0) {
strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
return;
}
if (op_def.output_arg_size() == 1) {
if (graph_op_def.output_arg_size() == 1) {
// One output, no need for NameRangeMap
if (is_list_output[0]) {
strings::StrAppend(out,
@@ -752,7 +861,7 @@ void OpInfo::GetOutput(string* out) const {
".UpdateStatus(_status_);\n", " return;\n");
strings::StrAppend(out, " }\n\n");

for (int i = 0; i < op_def.output_arg_size(); ++i) {
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
const string arg_range = strings::StrCat(
"_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
if (is_list_output[i]) {
@@ -776,11 +885,13 @@ string OpInfo::GetConstructorBody() const {

strings::StrAppend(&body, " ", return_on_error, "\n");

for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::",
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(",
scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n");
for (int i = 0; i < graph_op_def.input_arg_size(); ++i) {
const auto& arg(graph_op_def.input_arg(i));
const auto& api_def_arg(api_def.in_arg(i));
strings::StrAppend(
&body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::",
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ",
AvoidCPPKeywords(api_def_arg.rename_to()), ");\n");
strings::StrAppend(&body, " ", return_on_error, "\n");
}

@@ -791,19 +902,21 @@ string OpInfo::GetConstructorBody() const {
&body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
graph_op_def.name(), "\")\n");
const string spaces = " ";
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n");
for (int i = 0; i < api_def.in_arg_size(); ++i) {
const auto& arg(api_def.in_arg(i));
strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n");
}
for (int i = 0; i < op_def.attr_size(); ++i) {
for (int i = 0; i < api_def.attr_size(); ++i) {
const auto& graph_attr(graph_op_def.attr(i));
const auto& attr(op_def.attr(i));
if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) {
const auto& api_def_attr(api_def.attr(i));
if (inferred_input_attrs.find(api_def_attr.name()) !=
inferred_input_attrs.end()) {
continue;
}
const string attr_name = attr.has_default_value()
? strings::StrCat("attrs.", attr.name(), "_")
: AvoidCPPKeywords(attr.name());
const string attr_name =
api_def_attr.has_default_value()
? strings::StrCat("attrs.", api_def_attr.rename_to(), "_")
: AvoidCPPKeywords(api_def_attr.rename_to());
strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
attr_name, ")\n");
}
@@ -845,10 +958,10 @@ void OpInfo::WriteClassDef(WritableFile* cc) const {
TF_CHECK_OK(cc->Append(class_def));
}

void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def,
void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases, WritableFile* h,
WritableFile* cc) {
OpInfo op_info(graph_op_def, interface_op_def, aliases);
OpInfo op_info(graph_op_def, api_def, aliases);

op_info.WriteClassDecl(h);
op_info.WriteClassDef(cc);
@@ -943,8 +1056,9 @@ string MakeInternal(const string& fname) {

}  // namespace

void WriteCCOps(const OpList& ops, const string& dot_h_fname,
const string& dot_cc_fname, const string& overrides_fnames) {
void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& dot_h_fname, const string& dot_cc_fname,
const string& overrides_fnames) {
Env* env = Env::Default();

// Load the override map.
@@ -984,24 +1098,23 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname,
// code depends on it.
if (graph_op_def.name() == "Const") continue;

// Incorporate overrides from override_map.
OpDef interface_op_def = graph_op_def;
const OpGenOverride* op_override =
override_map.ApplyOverride(&interface_op_def);
std::vector<string> aliases;
if (op_override) {
if (op_override->skip()) continue;
aliases.assign(op_override->alias().begin(), op_override->alias().end());
if (op_override->hide()) {
// Write hidden ops to _internal.h and _internal.cc.
WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(),
internal_cc.get());
continue;
}
}
const auto* api_def = api_def_map.GetApiDef(graph_op_def.name());

std::vector<string> aliases;
if (api_def->visibility() == ApiDef::SKIP) continue;
// First endpoint is canonical, the rest are aliases.
for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size();
++endpoint_i) {
aliases.push_back(api_def->endpoint(endpoint_i).name());
}
if (api_def->visibility() == ApiDef::HIDDEN) {
// Write hidden ops to _internal.h and _internal.cc.
WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(),
internal_cc.get());
continue;
}
// This isn't a hidden op, write it to the main files.
WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get());
WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get());
}

FinishFiles(false, h.get(), cc.get(), op_header_guard);
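The SKIP/HIDDEN/endpoint handling above is driven by ApiDef protos. A minimal sketch of an entry that exercises it, with a hypothetical op name mirroring the tests added later in this commit:

op {
  graph_op_name: "Foo"
  visibility: HIDDEN  # SKIP drops the op entirely; HIDDEN routes it to _internal.h
  endpoint { name: "Foo1" }  # the first endpoint names the generated class
  endpoint { name: "Foo2" }  # later endpoints become typedef aliases
}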
@@ -17,13 +17,15 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_

#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

/// Result is written to files dot_h and dot_cc.
void WriteCCOps(const OpList& ops, const string& dot_h_fname,
const string& dot_cc_fname, const string& overrides_fnames);
void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& dot_h_fname, const string& dot_cc_fname,
const string& overrides_fnames);

}  // namespace tensorflow
@@ -16,7 +16,11 @@ limitations under the License.
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/types.h"
@@ -24,10 +28,28 @@ namespace tensorflow {
namespace {

void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
const std::string& overrides_fnames, bool include_internal) {
const std::string& overrides_fnames, bool include_internal,
const std::vector<string>& api_def_dirs) {
OpList ops;
OpRegistry::Global()->Export(include_internal, &ops);
WriteCCOps(ops, dot_h, dot_cc, overrides_fnames);
ApiDefMap api_def_map(ops);
if (!api_def_dirs.empty()) {
Env* env = Env::Default();
// Only load files that correspond to "ops".
for (const auto& op : ops.op()) {
for (const auto& api_def_dir : api_def_dirs) {
const std::string api_def_file_pattern =
io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
if (env->FileExists(api_def_file_pattern).ok()) {
TF_CHECK_OK(api_def_map.LoadFile(env, api_def_file_pattern));
}
}
}
}

api_def_map.UpdateDocs();

WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames);
}

}  // namespace
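Under this naming convention, for example, the ApiDef for an op named MatMul would be looked up at <api_def_dir>/api_def_MatMul.pbtxt; the op name here is purely illustrative.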
@@ -35,18 +57,24 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,

int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc != 5) {
// TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt
// as an argument.
if (argc != 6) {
for (int i = 1; i < argc; ++i) {
fprintf(stderr, "Arg %d = %s\n", i, argv[i]);
}
fprintf(stderr,
"Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n"
"Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal "
"api_def_dirs1,api_def_dir2 ...\n"
" include_internal: 1 means include internal ops\n",
argv[0]);
exit(1);
}

bool include_internal = tensorflow::StringPiece("1") == argv[4];
tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal);
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
argv[5], ",", tensorflow::str_util::SkipEmpty());
tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal,
api_def_dirs);
return 0;
}
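With the extra argument, an invocation of the generator binary now takes a form like the following (binary and path names illustrative, matching the updated Usage string above):

cc_op_gen ops.h ops.cc op_gen_overrides.pbtxt 1 tensorflow/core/api_def/base_api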
tensorflow/cc/framework/cc_op_gen_test.cc (new file, 195 lines)
@@ -0,0 +1,195 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/cc_op_gen.h"

#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

// TODO(annarev): Remove this op_gen_overrides.pbtxt reference.
// It is needed only because WriteCCOps takes it as an argument.
constexpr char kOverridesFnames[] =
"tensorflow/cc/ops/op_gen_overrides.pbtxt";
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
input_arg {
name: "images"
description: "Images to process."
}
input_arg {
name: "dim"
description: "Description for dim."
type: DT_FLOAT
}
output_arg {
name: "output"
description: "Description for output."
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
description: "Type for images"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
}
}
default_value {
i: 1
}
}
summary: "Summary for op Foo."
description: "Description for op Foo."
}
)";

void ExpectHasSubstr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(s.contains(expected))
<< "'" << s << "' does not contain '" << expected << "'";
}

void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
EXPECT_FALSE(s.contains(expected))
<< "'" << s << "' contains '" << expected << "'";
}

void ExpectSubstrOrder(const string& s, const string& before,
const string& after) {
int before_pos = s.find(before);
int after_pos = s.find(after);
ASSERT_NE(std::string::npos, before_pos);
ASSERT_NE(std::string::npos, after_pos);
EXPECT_LT(before_pos, after_pos)
<< before << " is not before " << after << " in " << s;
}

// Runs WriteCCOps and stores output in (internal_)cc_file_path and
// (internal_)h_file_path.
void GenerateCcOpFiles(Env* env, const OpList& ops,
const ApiDefMap& api_def_map, string* h_file_text,
string* internal_h_file_text) {
const string& tmpdir = testing::TmpDir();

const auto h_file_path = io::JoinPath(tmpdir, "test.h");
const auto cc_file_path = io::JoinPath(tmpdir, "test.cc");
const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h");
const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc");

WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames);

TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text));
TF_ASSERT_OK(
ReadFileToString(env, internal_h_file_path, internal_h_file_text));
}

TEST(CcOpGenTest, TestVisibilityChangedToHidden) {
const string api_def = R"(
op {
graph_op_name: "Foo"
visibility: HIDDEN
}
)";
Env* env = Env::Default();
OpList op_defs;
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);  // NOLINT
ApiDefMap api_def_map(op_defs);

string h_file_text, internal_h_file_text;
// Without ApiDef
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectHasSubstr(h_file_text, "class Foo");
ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo");

// With ApiDef
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectHasSubstr(internal_h_file_text, "class Foo");
ExpectDoesNotHaveSubstr(h_file_text, "class Foo");
}

TEST(CcOpGenTest, TestArgNameChanges) {
const string api_def = R"(
op {
graph_op_name: "Foo"
arg_order: "dim"
arg_order: "images"
}
)";
Env* env = Env::Default();
OpList op_defs;
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);  // NOLINT

ApiDefMap api_def_map(op_defs);
string cc_file_text, h_file_text;
string internal_cc_file_text, internal_h_file_text;
// Without ApiDef
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectSubstrOrder(h_file_text, "Input images", "Input dim");

// With ApiDef
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectSubstrOrder(h_file_text, "Input dim", "Input images");
}

TEST(CcOpGenTest, TestEndpoints) {
const string api_def = R"(
op {
graph_op_name: "Foo"
endpoint {
name: "Foo1"
}
endpoint {
name: "Foo2"
}
}
)";
Env* env = Env::Default();
OpList op_defs;
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);  // NOLINT

ApiDefMap api_def_map(op_defs);
string cc_file_text, h_file_text;
string internal_cc_file_text, internal_h_file_text;
// Without ApiDef
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectHasSubstr(h_file_text, "class Foo {");
ExpectDoesNotHaveSubstr(h_file_text, "class Foo1");
ExpectDoesNotHaveSubstr(h_file_text, "class Foo2");

// With ApiDef
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
&internal_h_file_text);
ExpectHasSubstr(h_file_text, "class Foo1");
ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2");
ExpectDoesNotHaveSubstr(h_file_text, "class Foo {");
}
}  // namespace
}  // namespace tensorflow
@@ -21,6 +21,9 @@ namespace tensorflow {
/// Tag for the `gpu` graph.
constexpr char kSavedModelTagGpu[] = "gpu";

/// Tag for the `tpu` graph.
constexpr char kSavedModelTagTpu[] = "tpu";

/// Tag for the `serving` graph.
constexpr char kSavedModelTagServe[] = "serve";
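A hedged sketch of how such a tag is consumed when loading a SavedModel, using the existing C++ loader API from tensorflow/cc/saved_model (the export path is hypothetical):

// LoadSavedModel restores the graph whose MetaGraphDef carries the given tags.
tensorflow::SavedModelBundle bundle;
tensorflow::Status s =
    tensorflow::LoadSavedModel(tensorflow::SessionOptions(),
                               tensorflow::RunOptions(), "/tmp/exported_model",
                               {tensorflow::kSavedModelTagTpu}, &bundle);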
@@ -24,10 +24,12 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest


@test_util.with_c_api
class FunctionTest(XLATestCase):

def testFunction(self):
@@ -623,11 +623,12 @@ class FunctionalizeCond {
FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library)
: clusters_(graph->num_node_ids()), library_(library), graph_(graph) {}

// Returns a vector of Merge nodes from the clustered graph where the nodes
// Returns a vector of Switch nodes from the clustered graph where the nodes
// are sorted by the number of switch nodes minus number of merge nodes
// from a root of the clustered graph to the given Merge node, with ties
// broken by the representative of the Cluster.
std::vector<std::pair<int, Cluster*>> SortedMergeNodes();
// broken by the representative of the Cluster. This corresponds to sorting by
// nesting depth, from deepest nested to outermost.
std::vector<std::pair<int, Cluster*>> SortedSwitchNodes();

// Returns whether the graph has no conditionals.
bool NoConditionals() const { return merge_nodes_.empty(); }
@@ -654,15 +655,17 @@ class FunctionalizeCond {
// extracting the bodies needed for the then and else branch, creates a XlaIf
// node, removing the nodes of the branches from the graph and replacing the
// merge node with a XlaIf.
Status ConvertMergeToXlaIf(Cluster* merge_cluster);
Status ConvertCorrespondingMergeToXlaIf(Cluster* switch_cluster);

// Removes a Switch cluster feeding directly into a Merge cluster by removing
// the Switch and Merge nodes and collapsing into a single cluster.
Status RemoveTrivialMerge(Cluster* merge_cluster);
Status RemoveTrivialSwitch(Cluster* switch_cluster);

// Returns the switch cluster corresponding to the merge node. This function
// only returns the switch cluster in the simple case where we have a switch
// node is the entry of a diamond corresponding to a conditional:
// Returns the merge cluster corresponding to the switch node. This function
// only returns the merge cluster in the case where we have a switch node that
// is the single entry point for all paths to a common merge cluster, this
// merge cluster may be created by combining multiple merge clusters, that
// share the switch cluster as common ancestor, together.
//
// Switch
// / \
@@ -671,8 +674,9 @@ class FunctionalizeCond {
// merge_cluster
//
// Note: either of the branches may be empty. The case where both branches are
// empty is handled by RemoveTrivialMerge.
gtl::optional<Cluster*> GetSwitchCluster(const Cluster& merge_cluster);
// empty is handled by RemoveTrivialSwitch.
gtl::optional<Cluster*> CreateCorrespondingMergeCluster(
Cluster* switch_cluster);

// Determines the arguments needed as input to the Merge cluster originating
// from the Switch cluster.
@@ -793,6 +797,10 @@ bool IsDeadSwitch(const Node* node) {
}

void FunctionalizeCond::CreateClusters() {
ClusterHandle source_cluster = ClusterHandle(Graph::kSourceId);
auto& source = clusters_.at(source_cluster);
std::deque<std::pair<ClusterHandle, std::deque<Node*>>> workqueue;
workqueue.push_back({source_cluster, {}});
for (Node* node : graph_->nodes()) {
if (IsSwitch(node)) {
switch_nodes_.insert(node);
@@ -801,6 +809,12 @@ void FunctionalizeCond::CreateClusters() {
}
ClusterHandle& cluster = clusters_.at(node).Get();
cluster = ClusterHandle(node->id());
// Group all source clusters together.
if (node->IsSource() || node->in_edges().empty()) {
clusters_.at(node).Merge(&source);
source.Merge(&clusters_.at(node));
workqueue.front().second.push_back(node);
}
}

// If there are no Merge nodes, then terminate.
@@ -815,20 +829,118 @@ void FunctionalizeCond::CreateClusters() {
// conservatively assuming all merge nodes become XlaIf nodes.
clusters_.resize(clusters_.size() + merge_nodes_.size());

// Merge a cluster with its input, unless the input is a Switch node or
// the node is a Merge node.
for (const Node* node : graph_->nodes()) {
if (IsMerge(node) || IsSwitch(node) || !node->IsOp()) {
continue;
}
for (const Node* in : node->in_nodes()) {
if (in->IsOp() && !IsSwitch(in) && !IsMerge(in)) {
clusters_.at(node).Merge(&clusters_.at(in));
std::unordered_set<Node*> marked;
while (!workqueue.empty()) {
auto cluster_queue = workqueue.front();
VLOG(4) << "Cluster: " << cluster_queue.first << " Queue: {"
<< str_util::Join(cluster_queue.second, ",",
[](string* output, const Node* node) {
strings::StrAppend(output, node->id());
})
<< "}";

UnionFind<ClusterHandle>& repr = clusters_.at(cluster_queue.first);
workqueue.pop_front();
std::deque<Node*> switch_nodes;
std::deque<Node*> merge_nodes;
std::unordered_set<Node*> cluster_member;
while (!cluster_queue.second.empty()) {
// Iterate node workqueue and flow forward merging all nodes reachable
// that are neither a Switch or a Merge and whose inputs are all part of
// the same cluster.
Node* cur = cluster_queue.second.front();
cluster_queue.second.pop_front();
if (marked.find(cur) != marked.end()) {
continue;
}
if (IsMerge(cur)) {
merge_nodes.push_back(cur);
marked.insert(cur);
continue;
}
if (IsSwitch(cur)) {
switch_nodes.push_back(cur);
marked.insert(cur);
continue;
}
clusters_.at(cur).Merge(&repr);
cluster_member.insert(cur);
for (Node* out : cur->out_nodes()) {
bool all_ancestors_in_cluster = true;
for (Node* in : out->in_nodes()) {
if (IsMerge(out)) {
merge_nodes.push_back(out);
}
if (IsSwitch(out)) {
switch_nodes.push_back(out);
}
if (cluster_member.find(in) == cluster_member.end()) {
all_ancestors_in_cluster = false;
break;
}
}
if (all_ancestors_in_cluster && out->IsOp()) {
cluster_queue.second.push_back(out);
marked.insert(cur);
}
}
}
// Group all source clusters together.
if (node->IsSource() || node->in_edges().empty()) {
clusters_.at(node).Merge(&clusters_.at(ClusterHandle(Graph::kSourceId)));

VLOG(4) << "Switches: {"
<< str_util::Join(switch_nodes, ",",
[](string* output, const Node* node) {
strings::StrAppend(output, node->id());
})
<< "}";

// Merge Switch nodes with common predicate.
std::unordered_map<Node*, std::vector<Node*>> predicate_to_switch;
for (Node* node : switch_nodes) {
Node* tmp;
TF_CHECK_OK(node->input_node(1, &tmp));
predicate_to_switch[tmp].push_back(node);
}
for (auto kv : predicate_to_switch) {
Node* first = kv.second.front();
for (Node* switch_node : kv.second) {
clusters_.at(first).Merge(&clusters_.at(switch_node));
}
}

// Enqueue each edge of the switch node separately. That is, group all the
// nodes that are due to the true/false edge of the switch together and
// consider all nodes that only have a control dependency on the switch node
// separately. We want to group together all nodes that are part of the same
// branch, as these will be extracted into the `then` and `else` functions
// of the functional if. The ops due to control edges are different as they
// could be involved with either branch and merging them here could result
// in invalid graphs.
for (auto kv : predicate_to_switch) {
ClusterHandle none = ClusterHandle(-1);
ClusterHandle first[2] = {none, none};
std::deque<Node*>* queue[2];
for (auto switch_node : kv.second) {
for (const auto e : switch_node->out_edges()) {
if (IsSwitch(e->dst()) || IsMerge(e->dst())) {
continue;
}
// Control edges are enqueued on their own.
if (e->IsControlEdge()) {
workqueue.push_back({Representative(e->dst()), {e->dst()}});
continue;
}
// Combine all outputs of the same output port of a switch cluster
// into the same workqueue entry.
if (first[e->src_output()] == none) {
ClusterHandle repr = Representative(e->dst());
first[e->src_output()] = repr;
workqueue.push_back({repr, {}});
queue[e->src_output()] = &workqueue.back().second;
}
clusters_.at(first[e->src_output()]).Merge(&clusters_.at(e->dst()));
queue[e->src_output()]->push_back(e->dst());
}
}
}
}
}
@@ -910,74 +1022,60 @@ void FunctionalizeCond::CreateClusteredGraph() {
update_cluster_for_node(node).merge_nodes.insert(node);
}

// Merge Switch nodes with common predicate.
std::unordered_map<Node*, std::vector<Node*>> predicate_to_switch;
for (Node* node : switch_nodes_) {
Node* tmp;
TF_CHECK_OK(node->input_node(1, &tmp));
predicate_to_switch[tmp].push_back(node);
}
for (auto kv : predicate_to_switch) {
Cluster& first = clustered_graph_.at(Representative(kv.second.front()));
for (Node* switch_node : kv.second) {
ClusterHandle handle = Representative(switch_node);
Cluster& cluster = clustered_graph_.at(handle);
ContractEdge(&cluster, &first, /*remove_from_graph=*/true);
}
}

// Merge Merge nodes with common input together.
for (Node* node : merge_nodes_) {
Cluster& cluster = clustered_graph_.at(Representative(node));
for (const Node* in : node->in_nodes()) {
if (!in->IsOp()) {
continue;
}
Cluster& cluster_node_in = clustered_graph_.at(Representative(in));
// ContractEdge can modify out_nodes of cluster_node_in, so traverse
// over out_nodes assuming it does.
for (auto it = cluster_node_in.out_nodes.begin();
it != cluster_node_in.out_nodes.end();) {
if (!(*it)->merge_nodes.empty()) {
ContractEdge(*it++, &cluster, /*remove_from_graph=*/true);
} else {
++it;
}
}
}
}

VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_);
VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_);
}

gtl::optional<FunctionalizeCond::Cluster*> FunctionalizeCond::GetSwitchCluster(
const Cluster& merge_cluster) {
VLOG(3) << "GetSwitchCluster for " << merge_cluster.representative;
gtl::optional<Cluster*> switch_cluster;
if (merge_cluster.in_nodes.size() > 2) {
return gtl::nullopt;
gtl::optional<FunctionalizeCond::Cluster*>
FunctionalizeCond::CreateCorrespondingMergeCluster(Cluster* switch_cluster) {
VLOG(3) << "CreateCorrespondingMergeCluster for "
<< switch_cluster->representative;
std::unordered_set<Cluster*> merges;
std::unordered_set<Cluster*> dominated;
dominated.insert(switch_cluster);
std::deque<Cluster*> queue;
auto enqueue_or_update_merge = [this, &queue, &merges](Cluster* c) {
if (c->merge_nodes.empty()) {
queue.push_back(c);
} else {
merges.insert(c);
}
};
// Enqueue all the outputs of the switch cluster in the workqueue.
for (auto* out : switch_cluster->out_nodes) {
enqueue_or_update_merge(out);
}
for (Cluster* in : merge_cluster.in_nodes) {
Cluster* cluster = in;
if (in->switch_nodes.empty()) {
if (in->in_nodes.size() != 1 || in->out_nodes.size() != 1) {
std::unordered_set<Cluster*> visited;
while (!queue.empty()) {
Cluster* cur = queue.front();
queue.pop_front();
if (visited.find(cur) != visited.end()) {
continue;
}
visited.insert(cur);
// Ensure all inputs to the current node are in the dominated set.
for (Cluster* in : cur->in_nodes) {
if (dominated.find(in) == dominated.end()) {
return gtl::nullopt;
}
// There is only a single `in` cluster.
cluster = *in->in_nodes.begin();
}
if (cluster->switch_nodes.empty()) {
return gtl::nullopt;
}

if (switch_cluster.has_value() && *switch_cluster != cluster) {
return gtl::nullopt;
} else {
switch_cluster = cluster;
for (Cluster* out : cur->out_nodes) {
// No switch nodes beyond the entry one is expected.
if (!out->switch_nodes.empty()) {
return gtl::nullopt;
}
enqueue_or_update_merge(out);
}
}
return switch_cluster;
auto it = merges.begin();
Cluster* merge_cluster = *it;
for (++it; it != merges.end(); ++it) {
ContractEdge(*it, merge_cluster);
}

// TODO(jpienaar): Clean up graph, merging nodes.

return merge_cluster;
}

xla::StatusOr<FunctionalizeCond::CondArgs> FunctionalizeCond::DetermineCondArgs(
@ -1221,11 +1319,11 @@ void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) {
  }
}

Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) {
  Cluster* switch_cluster = *merge_cluster->in_nodes.begin();
  if (switch_cluster->switch_nodes.empty()) {
Status FunctionalizeCond::RemoveTrivialSwitch(Cluster* switch_cluster) {
  Cluster* merge_cluster = *switch_cluster->out_nodes.begin();
  if (merge_cluster->merge_nodes.empty()) {
    return errors::FailedPrecondition(
        "Not a trivial merge: no Switch node feeding into Merge node");
        "Not a trivial switch: no Merge node feeding into Switch node");
  }

  for (auto it = merge_cluster->merge_nodes.begin();
@ -1252,17 +1350,25 @@ Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) {
  return Status::OK();
}

Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) {
  VLOG(1) << "ConvertMergeToXlaIf for " << merge_cluster->representative;
  gtl::optional<Cluster*> switch_cluster = GetSwitchCluster(*merge_cluster);
  if (!switch_cluster.has_value()) {
Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf(
    Cluster* switch_cluster) {
  VLOG(1) << "ConvertMergeToXlaIf for " << switch_cluster->representative;
  gtl::optional<Cluster*> maybe_merge =
      CreateCorrespondingMergeCluster(switch_cluster);
  if (!maybe_merge.has_value()) {
    return errors::FailedPrecondition(
        "Merge cluster was not part of a simple conditional in the clustered "
        "graph. Graph nodes in merge cluster ",
        NodesToString(merge_cluster->merge_nodes));
        "Switch cluster was not part of a simple conditional in the clustered "
        "graph. Graph nodes in switch cluster ",
        NodesToString(switch_cluster->switch_nodes));
  }
  Cluster* merge_cluster = *maybe_merge;
  if (merge_cluster->merge_nodes.empty()) {
    return errors::Internal(
        "Merge node in clustered graph contains no merge nodes: ",
        merge_cluster->representative.ToString());
  }
  TF_ASSIGN_OR_RETURN(auto cond_args,
                      DetermineCondArgs(*merge_cluster, **switch_cluster));
                      DetermineCondArgs(*merge_cluster, *switch_cluster));

  // Sort the outputs by ID to produce more stable output.
  std::vector<Node*> outputs(merge_cluster->merge_nodes.begin(),
@ -1278,7 +1384,7 @@ Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) {
  // Remove the old nodes from the graph_ and contract the edges of the
  // clustered graph.
  for (auto in : merge_cluster->in_nodes) {
    if (in != *switch_cluster) {
    if (in != switch_cluster) {
      RemoveClusterNodes(in);
    }
  }
@ -1286,20 +1392,20 @@ Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) {
  RemoveUnusedArgs(cond_args.args);
  auto in_nodes = merge_cluster->in_nodes;
  for (auto it = in_nodes.begin(); it != in_nodes.end();) {
    ContractEdge(*it++, merge_cluster);
    ContractEdge(*it++, switch_cluster);
  }
  ContractEdge(*switch_cluster, merge_cluster);
  clusters_[if_node].Get() = ClusterHandle(merge_cluster->representative);
  ContractEdge(merge_cluster, switch_cluster);
  clusters_[if_node].Get() = ClusterHandle(switch_cluster->representative);

  return Status::OK();
}

std::vector<std::pair<int, FunctionalizeCond::Cluster*>>
FunctionalizeCond::SortedMergeNodes() {
FunctionalizeCond::SortedSwitchNodes() {
  VLOG(2) << "ProcessClusteredGraph";
  std::stack<std::pair<int, Cluster*>> stack;
  // Initialize with the source node.
  stack.push({0, &clustered_graph_[ClusterHandle(Graph::kSourceId)]});
  stack.push({0, &clustered_graph_[Representative(graph_->source_node())]});

  // Perform a depth-first traversal of the clustered graph computing the
  // switch-merge depth.
@ -1317,10 +1423,10 @@ FunctionalizeCond::SortedMergeNodes() {

    size_t new_depth = depth;
    if (!n->merge_nodes.empty()) {
      queue.emplace_back(depth, n);
      --new_depth;
    }
    if (!n->switch_nodes.empty()) {
      queue.emplace_back(depth, n);
      ++new_depth;
    }
    for (Cluster* e : n->out_nodes) {
@ -1350,25 +1456,30 @@ Status FunctionalizeCond::Functionalize(Graph* graph,
  }
  fc.CreateClusteredGraph();

  auto queue = fc.SortedMergeNodes();
  auto queue = fc.SortedSwitchNodes();
  for (auto it = queue.begin(); it != queue.end();) {
    Cluster* merge_cluster = (*it).second;
    Cluster* switch_cluster = (*it).second;
    ++it;
    if (merge_cluster->in_nodes.size() == 1) {
      TF_RETURN_IF_ERROR(fc.RemoveTrivialMerge(merge_cluster));
    if (switch_cluster->out_nodes.size() == 1) {
      TF_RETURN_IF_ERROR(fc.RemoveTrivialSwitch(switch_cluster));
    } else {
      TF_RETURN_IF_ERROR(fc.ConvertMergeToXlaIf(merge_cluster));
      TF_RETURN_IF_ERROR(fc.ConvertCorrespondingMergeToXlaIf(switch_cluster));
    }

    // Contract the newly Merge-free merge_cluster with incoming nodes without
    // Contract the newly Switch-free switch_cluster with outgoing nodes without
    // Switch or Merge nodes.
    std::vector<Cluster*> in_nodes(merge_cluster->in_nodes.begin(),
                                   merge_cluster->in_nodes.end());
    for (auto in : in_nodes) {
      if (in->merge_nodes.empty() && in->switch_nodes.empty()) {
        fc.ContractEdge(in, merge_cluster);
    for (auto& nodes : {switch_cluster->out_nodes, switch_cluster->in_nodes}) {
      std::vector<Cluster*> copy_nodes(nodes.begin(), nodes.end());
      for (auto* node : copy_nodes) {
        if (node->merge_nodes.empty() && node->switch_nodes.empty()) {
          fc.ContractEdge(node, switch_cluster);
        }
      }
    }

    VLOG(3) << "Graph with clusters: "
            << DebugString(*fc.graph_, &fc.clusters_);
    VLOG(3) << "ClusteredGraph: " << DebugString(fc.clustered_graph_);
  }

  if (!fc.switch_nodes_.empty()) {

@ -153,6 +153,7 @@ bool ComputationBuilder::MakeWindow(
    } else {
      dim->set_window_dilation(1);
    }
    dim->set_window_reversal(false);
  }
  return true;
}
@ -1163,6 +1164,34 @@ ComputationDataHandle ComputationBuilder::ConvertElementType(
  return ParseOpResponse(s, &response);
}

ComputationDataHandle ComputationBuilder::BitcastConvertType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    first_error_ = shape_status.status();
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();

  ConvertRequest request;
  *request.mutable_operand() = operand;
  request.set_new_element_type(new_element_type);
  OpRequest op_request;
  *op_request.mutable_computation() = computation_.handle();
  *op_request.mutable_bitcast_convert_request() = request;
  AddCommonFieldsToOpRequest(&op_request);
  OpResponse response;

  VLOG(2) << "making bitcast convert request";
  Status s = client_->stub()->Op(&op_request, &response);

  return ParseOpResponse(s, &response);
}

ComputationDataHandle ComputationBuilder::SquareF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(2.0),

@ -121,14 +121,10 @@ class ComputationBuilder {
  // result, OpMetadata is set on the Computation Builder. All subsequent
  // instructions generated via this Computation Builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(const OpMetadata& metadata) {
    metadata_ = metadata;
  }
  void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }

  // Clears the HloMetadata state.
  void ClearOpMetadata() {
    metadata_.Clear();
  }
  void ClearOpMetadata() { metadata_.Clear(); }

  // Sets an OpSharding that will be attached to all instructions until cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
@ -673,6 +669,13 @@ class ComputationBuilder {
  ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
                                           PrimitiveType new_element_type);

  // Enqueues a no-op instruction onto the computation that changes
  // the element type of the operand array to primitive_type. The
  // bit-widths of the source and destination element types must be
  // identical.
  ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand,
                                           PrimitiveType new_element_type);
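
A hedged usage sketch of the new builder method (the builder `b` and its setup are assumed, not part of this change; the bit-widths must match, here f32 -> s32):

  ComputationDataHandle input = b.ConstantR1<float>({1.0f, -2.5f});
  // Same bits, new element type: values are reinterpreted, not converted.
  ComputationDataHandle bits = b.BitcastConvertType(input, S32);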

  // Enqueues a float32 reciprocal instruction onto the computation.
  // (float32 is specified as there is an implicit float32 -1.0f constant
  // exponent).

@ -275,9 +275,6 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
                                         device_ordinal, options));
}

// Copy the literal data to the device with the given ordinal and return as a
// ScopedShapedBuffer. The given memory allocator is used for device memory
// allocation.
StatusOr<std::unique_ptr<ScopedShapedBuffer>>
LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal,
                                   DeviceMemoryAllocator* allocator) {
@ -298,8 +295,6 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal,
  return std::move(scoped_buffer);
}

// Copy the data from the device contained in the given ShapedBuffer and
// return as a Literal.
StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
    const ShapedBuffer& shaped_buffer) {
  TF_ASSIGN_OR_RETURN(
@ -309,4 +304,22 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
      shaped_buffer);
}

Status LocalClient::TransferToInfeedLocal(const Literal& literal,
                                          int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  return backend().transfer_manager()->TransferLiteralToInfeed(executor,
                                                               literal);
}

StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
    const Shape& shape, int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  auto literal = MakeUnique<Literal>();
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
      executor, shape, literal.get()));
  return std::move(literal);
}

}  // namespace xla

@ -162,6 +162,20 @@ class LocalClient : public Client {
  StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
      const ShapedBuffer& shaped_buffer);

  // Transfer the given literal to the infeed queue of the given device.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferToInfeed.
  Status TransferToInfeedLocal(const Literal& literal, int device_ordinal);

  // Transfer and return a value of the given shape from the outfeed of the
  // given device.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferFromOutfeed.
  StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
      const Shape& shape, int device_ordinal);
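
A minimal call sketch (assuming a LocalClient* client obtained elsewhere, e.g. from the client library; the surrounding setup is not part of this diff):

  std::unique_ptr<Literal> literal = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
  TF_RETURN_IF_ERROR(client->TransferToInfeedLocal(*literal, /*device_ordinal=*/0));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
                      client->TransferFromOutfeedLocal(
                          ShapeUtil::MakeShape(F32, {3}), /*device_ordinal=*/0));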

  // Returns the platform that the underlying service targets.
  perftools::gputools::Platform* platform() const;

@ -1304,6 +1304,7 @@ cc_library(
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

@ -1638,10 +1639,14 @@ cc_library(
    deps = [
        ":buffer_liveness",
        ":hlo",
        ":hlo_alias_analysis",
        ":hlo_dce",
        ":hlo_graph_dumper",
        ":hlo_ordering",
        ":hlo_pass",
        ":liveness_util",
        ":logical_buffer",
        ":tuple_points_to_analysis",
        ":tuple_simplifier",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
@ -1656,15 +1661,17 @@ tf_cc_test(
    deps = [
        ":copy_insertion",
        ":hlo",
        ":hlo_graph_dumper",
        ":hlo_matchers",
        ":tuple_points_to_analysis",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

@ -1108,9 +1108,15 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
  if (IsAll(rhs, -1)) {
    auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
        Literal::One(rhs->shape().element_type()).CloneToUnique()));

    // Explicitly broadcast scalar 1 to the output shape, to avoid implicit
    // broadcast in divide HLO as we are trying to eliminate implicit
    // broadcasting at HLO level.
    auto* broadcast_one = computation_->AddInstruction(
        HloInstruction::CreateBroadcast(power->shape(), one, {}));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
                                            one, lhs));
                                            broadcast_one, lhs));
  }
  return Status::OK();
}
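
The rewrite relies on the elementwise identity $x^{-1} = 1/x$; because the constant $1$ is a scalar, it is broadcast to the shape of $x$ first so the resulting divide needs no implicit broadcast:

$$\mathrm{pow}(x, -1) \;\rightarrow\; \mathrm{divide}(\mathrm{broadcast}(1, \mathrm{shape}(x)),\; x).$$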
@ -1398,6 +1404,15 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
  auto operand = reduce_window->mutable_operand(0);
  const Window& window = reduce_window->window();
  auto function = reduce_window->to_apply();
  if (ShapeUtil::IsScalar(operand->shape())) {
    TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape()));
    return ReplaceWithNewInstruction(
        reduce_window,
        HloInstruction::CreateMap(reduce_window->shape(),
                                  {operand, reduce_window->mutable_operand(1)},
                                  function));
  }

  VLOG(10) << "Considering folding Pad: " << operand->ToString()
           << "\ninto reduce-window: " << reduce_window->ToString();

@ -761,8 +761,10 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, op::Divide(op::Constant(), param0));
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
  EXPECT_THAT(root, op::Divide(op::Broadcast(), param0));
  EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast);
  EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement<float>(),
            1);
}

TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
@ -1265,7 +1265,6 @@ const LogicalBuffer* AddBufferToColocatedSet(
  // CopyInsertion ensures root points-to set is unambiguous and distinct.
  const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
  DCHECK(!points_to.IsAmbiguous());
  DCHECK(points_to.IsDistinct());
  colocated_set->push_back(points_to.element(index)[0]);
  return colocated_set->back();
}

@ -1538,8 +1538,6 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
  auto output1 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));

  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@ -1556,10 +1554,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
  auto body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output1}));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));

  module->AddEntryComputation(builder.Build());
  RunCopyInsertion(module.get());
@ -1676,11 +1672,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));

  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
  auto root_add = builder.AddInstruction(HloInstruction::CreateBinary(
      while0->shape(), HloOpcode::kAdd, while0, while1));
  module->AddEntryComputation(builder.Build());
      while0->shape(), HloOpcode::kAdd, gte0, gte1));

  RunCopyInsertion(module.get());
  module->AddEntryComputation(builder.Build());

  {
    FlattenCallGraph flatten;
@ -1688,22 +1687,22 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
    EXPECT_TRUE(result);
  }

  RunCopyInsertion(module.get());

  auto sequence =
      CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();

  // To trigger b/38494731, we want a specific HLO sequence for the
  // root computation, so we overwrite that entry with a manually
  // crafted sequence.
  std::vector<const HloInstruction*> sequence_for_buffer_assignment = {
      input1, weights1, one, output1, tuple1, while1, input0,
      weights0, zero, output0, tuple0, while0, root_add};
  sequence[module->entry_computation()] = {
      input1, weights1, one, output1, while1->operand(0), while1,
      input0, weights0, zero, output0, while0->operand(0), while0,
      gte0, gte1, root_add};

  // If this ASSERT_TRUE fails, we constructed a bogus sequence above
  // and this test itself is buggy.
  ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assignment));

  sequence[module->entry_computation()] =
      std::move(sequence_for_buffer_assignment);
  ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()]));

  auto assignment =
      BufferAssigner::Run(
@ -1716,55 +1715,6 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
  EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}

// Test buffer assignment for while nodes with multiple uses.
// TODO(b/37245345): Fix buffer assignment for this case.
TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
  auto module = xla::MakeUnique<HloModule>(TestName());
  auto builder = HloComputation::Builder(TestName());

  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));

  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0));

  auto get0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
  auto get1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
  builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1));
  module->AddEntryComputation(builder.Build());

  RunCopyInsertion(module.get());

  {
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
    EXPECT_TRUE(result);
  }

  auto assignment = RunBufferAssignment(module.get());

  EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}

TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
  auto module = xla::MakeUnique<HloModule>(TestName());
  auto builder = HloComputation::Builder("entry");

@ -97,21 +97,32 @@ class Compiler {
  // Returns the ID of the platform that this compiler targets.
  virtual perftools::gputools::Platform::Id PlatformId() const = 0;

  // Runs HLO passes to optimize the given HLO module; returns the optimized
  // module.
  virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module,
      perftools::gputools::StreamExecutor* executor) = 0;

  // Compiles the HLO module for execution on a device given by the executor,
  // and returns an executable object or an error status. Takes ownership of the
  // HLO module and is free to transform it.
  // and returns an executable object or an error status. No HLO passes are
  // applied to module. Generally a module should be passed through RunHloPasses
  // prior to calling this method because some HLO passes are required for
  // correctness. Takes ownership of the HLO module and is free to transform it.
  //
  // The compiler may optionally specialize to the individual device
  // (not just type of device) indicated by the executor.
  //
  // Use the overload below to compile computations that run in parallel.
  virtual StatusOr<std::unique_ptr<Executable>> Compile(
  virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module,
      perftools::gputools::StreamExecutor* executor) = 0;
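
The intended two-phase call pattern after this split (a sketch; `compiler`, `module`, and `executor` are assumed to be in scope, not defined by this diff):

  TF_ASSIGN_OR_RETURN(module, compiler->RunHloPasses(std::move(module), executor));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      compiler->RunBackend(std::move(module), executor));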

  // Compiles a set of HLO modules that can run in parallel, potentially
  // communicating data between the modules, and returns a corresponding
  // sequence of executable objects.
  //
  // TODO(b/68666782): Remove this method after adding support for multiple
  // modules to RunHloPasses and RunBackend.
  virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::vector<std::unique_ptr<HloModule>> modules,
      std::vector<std::vector<perftools::gputools::StreamExecutor*>>

(File diff suppressed because it is too large.)
@ -25,12 +25,25 @@ limitations under the License.

namespace xla {

// HLO pass which inserts a copy of the root instruction (creating a new root)
// if the root is or points-to any constant or parameter instruction.
// If the root instruction is a Tuple, only tuple elements which point to
// constant or parameter instructions will be copied.
// Copy insertion is necessary because constant and parameter arrays have
// different lifetimes than computation results.
// Copy insertion is a legalization HLO pass which inserts copies (kCopy
// instructions) to eliminate several kinds of problems in the HLO module.
//
// (1) Entry parameter or a constant live out of the entry computation. Entry
//     computation arguments and constants have different lifetimes than the
//     computation result and cannot share the same allocation. Parameters and
//     constants live out of non-entry computations do not need copies.
//
// (2) Different values which are simultaneously live and which must be held
//     in the same buffer. This can occur in while bodies. Specifically, the
//     while loop state (the arguments to the while instruction) is updated
//     in-place and the update may clobber the value from the previous
//     iteration before the previous value is dead. Computations called from
//     kCall instructions do not need such copies because kCall has no update
//     in-place semantics.
//
// (3) The buffer set of the root instruction of the entry computation must be
//     unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and
//     InstructionAliasSet::IsDistinct return true.
class CopyInsertion : public HloPassInterface {
 public:
  tensorflow::StringPiece name() const override { return "copy-insertion"; }
@ -39,14 +52,16 @@ class CopyInsertion : public HloPassInterface {
  // (copies were inserted).
  StatusOr<bool> Run(HloModule* module) override;

 protected:
  // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
  // duplicate copies.
  StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);

  // A map containing all copies inserted during the copy insertion pass. The
  // key is the copied instruction and the value is the copy.
  tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
  // The CPU and GPU backends need additional copies added due to deficiencies
  // in buffer assignment. Specifically, copies are needed for constants
  // live-out of computations, and for values which are live-in and live-out of
  // the same computation. These copies are needed because buffer assignment
  // uses a computation-scoped analysis (TuplePointsToAnalysis) and has limited
  // visibility across computation boundaries. This method adds these necessary
  // copies. Returns whether the module was modified.
  //
  // TODO(b/62548313): Remove this when buffer assignment is module-scoped.
  static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
};

}  // namespace xla

(File diff suppressed because it is too large.)
@ -79,6 +79,7 @@ cc_library(
    deps = [
        ":compiler_functor",
        ":conv_canonicalization",
        ":cpu_copy_insertion",
        ":cpu_executable",
        ":cpu_instruction_fusion",
        ":cpu_options",
@ -103,7 +104,6 @@ cc_library(
        "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:buffer_liveness",
        "//tensorflow/compiler/xla/service:call_inliner",
        "//tensorflow/compiler/xla/service:copy_insertion",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:flatten_call_graph",
        "//tensorflow/compiler/xla/service:hlo",
@ -273,6 +273,7 @@ cc_library(
        "//tensorflow/compiler/xla/service/llvm_ir:ops",
        "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
        "//tensorflow/core:lib",
        "@llvm//:code_gen",
        "@llvm//:core",
        "@llvm//:support",
        "@llvm//:target",
@ -750,6 +751,38 @@ cc_library(
    ],
)

cc_library(
    name = "cpu_copy_insertion",
    srcs = ["cpu_copy_insertion.cc"],
    hdrs = ["cpu_copy_insertion.h"],
    deps = [
        "//tensorflow/compiler/xla/service:copy_insertion",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "//tensorflow/core:lib",
    ],
)

tf_cc_test(
    name = "cpu_copy_insertion_test",
    srcs = ["cpu_copy_insertion_test.cc"],
    deps = [
        ":cpu_copy_insertion",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_graph_dumper",
        "//tensorflow/compiler/xla/service:hlo_matchers",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

# -----------------------------------------------------------------------------

filegroup(

@ -53,7 +53,7 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
      // kernel and output.
      //
      // For simplicity, as a first step, we reshape the input and filter to
      // NHWC and HWIO order, respectively. This may lose precision but not
      // NHWC and HWIO order, respectively. This may lose precision but won't
      // break the soundness.
      HloInstruction* input = hlo->mutable_operand(0);

@ -98,14 +98,18 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
          HloInstruction::CreateTranspose(new_kernel_shape, kernel,
                                          new_kernel_dim_order));

      std::vector<int64> new_output_dim_order(num_dims);
      std::vector<int64> new_conv_dims(num_dims);
      auto output_batch_dim = dnums.output_batch_dimension();
      auto output_feature_dim = dnums.output_feature_dimension();
      new_output_dim_order[0] = output_batch_dim;
      new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim);
      for (int i = 0; i < num_spatial_dims; ++i) {
        new_output_dim_order[i + 1] = dnums.spatial_dimensions(i);
        new_conv_dims[i + 1] =
            hlo->shape().dimensions(dnums.spatial_dimensions(i));
      }
      new_output_dim_order[num_dims - 1] = output_feature_dim;
      new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim);
      Shape new_conv_shape =
          ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims);
@ -129,14 +133,11 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
          HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
                                         hlo->window(), new_dnums));

      // kConvolution inherits the dimension mapping of its input, so we need
      // to reshape the output back to the shape of the original convolution.
      // This is done by applying the inverse permutation of the collapsing
      // order of the input reshape.
      // Reshape the output back to the shape of the original convolution.
      TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
          hlo, HloInstruction::CreateTranspose(
                   hlo->shape(), new_conv,
                   InversePermutation(new_input_dim_order))));
                   InversePermutation(new_output_dim_order))));
      changed = true;
    }
  }
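
A worked example of the corrected transpose (illustrative values, not from this diff): for an NCHW convolution, new_output_dim_order is {0, 2, 3, 1}, and InversePermutation yields {0, 3, 1, 2}, the transpose order that restores the original layout. A self-contained sketch of the inverse computation:

#include <vector>

// inverse[perm[i]] = i; e.g. perm {0, 2, 3, 1} -> inverse {0, 3, 1, 2}.
std::vector<int> InversePermutationSketch(const std::vector<int>& perm) {
  std::vector<int> inverse(perm.size());
  for (int i = 0; i < static_cast<int>(perm.size()); ++i) {
    inverse[perm[i]] = i;
  }
  return inverse;
}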

@ -46,9 +46,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
@ -332,15 +332,16 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
  // (and sometimes after) copy insertion, to avoid dead code from interfering
  // with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<CopyInsertion>();
  pipeline.AddPass<FlattenCallGraph>();
  pipeline.AddPass<CpuCopyInsertion>();
  if (options::CpuParallelBackendRequested(module->config())) {
    // Re-run the outlining, in case any copies were inserted into the entry
    // computation.
    pipeline.AddPass<ParallelizationPreparation>(max_parallelism,
                                                 ShapeSizeBytesFunction());
    pipeline.AddPass<CpuCopyInsertion>();
  }
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<FlattenCallGraph>();
  return pipeline.Run(module).status();
}

@ -426,11 +427,25 @@ Status InitializeModuleHooks(

}  // namespace

StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module,
    perftools::gputools::StreamExecutor* /*stream_exec*/) {
  VLOG(2) << "Before optimization:";
  XLA_VLOG_LINES(2, module->ToString());

  TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false));

  VLOG(2) << "After optimization:";
  XLA_VLOG_LINES(2, module->ToString());
  return std::move(module);
}

StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module,
    perftools::gputools::StreamExecutor* stream_exec) {
  const string timer_message =
      "Compiling [" + module->name() + "] for CPU using JIT";
  ScopedLoggingTimer compiling_timer(timer_message, 1);
  XLA_SCOPED_LOGGING_TIMER(timer_message);

  VLOG(1) << "Compiling: " << module->name();
  TF_RET_CHECK(stream_exec != nullptr);
@ -458,14 +473,6 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
  llvm_module->setDataLayout(jit->data_layout());
  llvm_module->setTargetTriple(jit->target_triple().getTriple());

  VLOG(2) << "Before optimization:";
  XLA_VLOG_LINES(2, module->ToString());

  TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false));

  VLOG(2) << "After optimization:";
  XLA_VLOG_LINES(2, module->ToString());

  HloComputation* computation = module->entry_computation();
  std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
  if (module->config().hlo_profiling_enabled()) {
@ -537,11 +544,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
    parallel_computations.emplace(to_apply, instruction);
  }

  size_t entry_computation_profile_idx = hlo_to_profile_idx.size();
  IrEmitter ir_emitter(
      *module, *assignment, llvm_module.get(), std::move(hlo_to_profile_idx),
      /*entry_computation_profile_idx=*/entry_computation_profile_idx,
      jit->target_machine(), jit->external_constant_pool());
  IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
                       hlo_to_profile_idx, hlo_to_profile_idx.size(),
                       jit->target_machine(), jit->external_constant_pool());

  std::unique_ptr<HloInstructionMap<string>> function_names(
      new HloInstructionMap<string>());
@ -619,11 +624,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
  // before the entry computation. The order of computations returned from
  // GetEmbeddedComputations guarantees that a called computation occurs
  // before a caller computation.
  size_t entry_computation_profile_idx = hlo_to_profile_idx.size();
  IrEmitter ir_emitter(
      *module, *assignment, llvm_module.get(), std::move(hlo_to_profile_idx),
      /*entry_computation_profile_idx=*/entry_computation_profile_idx,
      jit->target_machine(), jit->external_constant_pool());
  IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
                       hlo_to_profile_idx, hlo_to_profile_idx.size(),
                       jit->target_machine(), jit->external_constant_pool());

  for (auto embedded_computation :
       computation->MakeEmbeddedComputationsList()) {

@ -116,7 +116,11 @@ class CpuCompiler : public LLVMCompiler {
  //   stream_execs)
  using LLVMCompiler::Compile;

  StatusOr<std::unique_ptr<Executable>> Compile(
  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module,
      perftools::gputools::StreamExecutor* stream_exec) override;

  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module,
      perftools::gputools::StreamExecutor* stream_exec) override;

tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc (new file, 43 lines)
@ -0,0 +1,43 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"

#include <memory>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

StatusOr<bool> CpuCopyInsertion::Run(HloModule* module) {
  CopyInsertion generic_copy_insertion;

  TF_ASSIGN_OR_RETURN(bool generic_changed, generic_copy_insertion.Run(module));

  // The CPU backend needs additional copies added due to deficiencies in
  // buffer assignment.
  TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed,
                      CopyInsertion::AddCopiesForBufferAssignment(module));

  return generic_changed || buffer_assignment_changed;
}

}  // namespace xla

tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h (new file, 42 lines)
@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// Besides the modifications made by the generic xla::CopyInsertion, this
// CPU-specific copy insertion pass also adds copies to values live out of
// computations satisfying certain conditions (defined by constant or
// parameter, etc). This is necessary because of deficiencies in buffer
// assignment. Specifically, buffer assignment is computation-scoped and does
// not recognize aliasing between arguments and outputs of computations.
//
// TODO(b/62548313): Remove this when buffer assignment is smarter
// (module-scoped).
class CpuCopyInsertion : public HloPassInterface {
 public:
  tensorflow::StringPiece name() const override { return "copy-insertion"; }

  StatusOr<bool> Run(HloModule* module) override;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
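
Like any HloPassInterface, the pass composes into a pipeline; a hypothetical wiring sketch (the pipeline name and surrounding code are illustrative, not part of this change):

  HloPassPipeline pipeline("cpu-legalization");
  pipeline.AddPass<CpuCopyInsertion>();
  TF_RETURN_IF_ERROR(pipeline.Run(module).status());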

tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc (new file, 139 lines)
@ -0,0 +1,139 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"

#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace xla {
namespace {

namespace op = xla::testing::opcode_matchers;

int64 CountCopies(const HloComputation& computation) {
  int64 count = 0;
  for (const auto& instruction : computation.instructions()) {
    if (instruction->opcode() == HloOpcode::kCopy) {
      count++;
    }
  }
  return count;
}

int64 CountCopies(const HloModule& module) {
  int64 count = 0;
  for (const auto& computation : module.computations()) {
    count += CountCopies(*computation);
  }
  return count;
}

class CpuCopyInsertionTest : public HloTestBase {
 protected:
  void InsertCopies(HloModule* module) {
    CpuCopyInsertion copy_insertion;
    ASSERT_IS_OK(copy_insertion.Run(module).status());
  }

  const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
};

TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
  // Test a while body and condition which are each simply a constant (root of
  // computation is a constant). Each constant should be copied.
  auto module = CreateNewModule();
  auto builder = HloComputation::Builder(TestName());
  auto param_0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));

  auto body_builder = HloComputation::Builder("body");
  body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  body_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0));

  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 3);

  EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter()));
  EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
  EXPECT_THAT(condition->root_instruction(), op::Copy(op::Constant()));
}

TEST_F(CpuCopyInsertionTest, TupleCall) {
  // Test a kCall instruction which calls a computation which produces a three
  // element tuple: one is a constant, one is a parameter, and one is produced
  // in the computation. The constant and parameter should be copied.
  auto module = CreateNewModule();
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_});

  auto sub_builder = HloComputation::Builder("subcomputation");
  auto sub_param = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto constant = sub_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
  auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, sub_param, constant));
  sub_builder.AddInstruction(
      HloInstruction::CreateTuple({sub_param, constant, add}));
  HloComputation* subcomputation =
      module->AddEmbeddedComputation(sub_builder.Build());

  builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {param}, subcomputation));

  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*subcomputation), 2);
  EXPECT_THAT(subcomputation->root_instruction(),
              op::Tuple(op::Copy(op::Parameter()), op::Copy(op::Constant()),
                        op::Add()));
}

}  // namespace
}  // namespace xla
@ -147,8 +147,9 @@ Status CpuExecutable::ExecuteComputeFunction(
    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
    HloExecutionProfile* hlo_execution_profile) {
  std::vector<se::DeviceMemoryBase> argument_buffers;
  for (int i = 0; i < arguments.size(); ++i) {
    argument_buffers.push_back(arguments[i]->buffer(/*index=*/{}));
  argument_buffers.reserve(arguments.size());
  for (const auto* argument : arguments) {
    argument_buffers.push_back(argument->buffer(/*index=*/{}));
  }
  return ExecuteComputeFunction(run_options, argument_buffers, buffers,
                                hlo_execution_profile);

@ -26,14 +26,14 @@ limitations under the License.

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Target/TargetRegisterInfo.h"
#include "llvm/Target/TargetSubtargetInfo.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
@ -795,7 +795,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
  // operand index is within the bounds. The unsigned comparison includes
  // checking whether the operand index >= 0.
  llvm_ir::IrArray::Index operand_index(source_index.size());
  llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
  llvm::Value* in_bounds_condition = ir_builder_.getTrue();
  for (int64 i = 0; i < rank; ++i) {
    llvm::Value* strided_index = ir_builder_.CreateNSWMul(
        source_index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
@ -1140,7 +1140,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
      return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0));
    };

    llvm::Value* in_bounds_condition = nullptr;
    llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
    for (int i = 0; i < num_spatial_dims; ++i) {
      llvm::ConstantInt* input_bound =
          ir_builder_.getInt64(window_util::DilatedBound(
@ -1153,9 +1153,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
      llvm::Value* dim_ok =
          ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole);
      in_bounds_condition =
          in_bounds_condition
              ? ir_builder_.CreateAnd(in_bounds_condition, dim_ok)
              : dim_ok;
          ir_builder_.CreateAnd(in_bounds_condition, dim_ok);
    }

    // Now we need to map the dilated base coordinates back to the actual

@ -86,6 +86,9 @@ class DfsHloVisitorBase {
  virtual Status HandleConvert(HloInstructionPtr hlo) {
    return HandleElementwiseUnary(hlo);
  }
  virtual Status HandleBitcastConvert(HloInstructionPtr hlo) {
    return HandleElementwiseUnary(hlo);
  }
  virtual Status HandleCopy(HloInstructionPtr hlo) {
    return HandleElementwiseUnary(hlo);
  }
@ -208,6 +211,7 @@ class DfsHloVisitorBase {
  virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0;
  virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
  virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
  virtual Status HandleConditional(HloInstructionPtr hlo) = 0;

  virtual Status HandlePad(HloInstructionPtr hlo) = 0;

@ -167,6 +167,9 @@ class DfsHloVisitorWithDefaultBase
  Status HandleWhile(HloInstructionPtr xla_while) override {
    return DefaultAction(xla_while);
  }
  Status HandleConditional(HloInstructionPtr conditional) override {
    return DefaultAction(conditional);
  }
  Status HandleRecv(HloInstructionPtr recv) override {
    return DefaultAction(recv);
  }
@ -110,6 +110,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
                             PrimitiveType_Name(from_type).c_str(),
                             PrimitiveType_Name(to_type).c_str());
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return ir_builder_->CreateBitCast(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type).c_str(),
          PrimitiveType_Name(to_type).c_str(),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
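
The lowering maps kBitcastConvert to an LLVM bitcast, the IR analogue of memcpy-based type punning; a standalone host-side sketch of the equal-width requirement (plain C++, not part of this diff):

#include <cstdint>
#include <cstring>

// Reinterpret an f32's bits as s32; the widths must match exactly.
int32_t BitcastF32ToS32(float value) {
  static_assert(sizeof(int32_t) == sizeof(float), "equal bit-widths required");
  int32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;  // e.g. BitcastF32ToS32(1.0f) == 0x3f800000
}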

@ -203,6 +223,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
                             PrimitiveType_Name(from_type).c_str(),
                             PrimitiveType_Name(to_type).c_str());
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return ir_builder_->CreateBitCast(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type).c_str(),
          PrimitiveType_Name(to_type).c_str(),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kExp:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value},
                                          {operand_value->getType()},
@ -1073,6 +1113,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
@ -1081,11 +1122,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kReal:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kTanh:
    case HloOpcode::kNot:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
@ -1094,6 +1135,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
        return EmitUnaryOp(hlo, operand_value);
      };
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
@ -1106,14 +1148,13 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNe:
    case HloOpcode::kOr:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSubtract:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* lhs = hlo->operand(0);

@ -343,15 +343,16 @@ tf_cc_test(
)

cc_library(
    name = "copy_insertion",
    srcs = ["copy_insertion.cc"],
    hdrs = ["copy_insertion.h"],
    name = "gpu_copy_insertion",
    srcs = ["gpu_copy_insertion.cc"],
    hdrs = ["gpu_copy_insertion.h"],
    deps = [
        ":ir_emission_utils",
        "//tensorflow/compiler/xla/service:call_graph",
        "//tensorflow/compiler/xla/service:copy_insertion",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:logical_buffer",
        "//tensorflow/compiler/xla/service:tuple_points_to_analysis",
        "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "//tensorflow/core:lib",
    ],
)
@ -427,8 +428,8 @@ cc_library(
    hdrs = ["gpu_compiler.h"],
    deps = [
        ":convolution_folding",
        ":copy_insertion",
        ":fusion_merger",
        ":gpu_copy_insertion",
        ":gpu_executable",
        ":hlo_schedule",
        ":instruction_fusion",
@ -574,11 +575,14 @@ tf_cc_test(
    deps = [
        ":instruction_fusion",
        ":while_transformer",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla/service:copy_insertion",
        "//tensorflow/compiler/xla/service:hlo_verifier",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

@ -258,22 +258,19 @@ tensorflow::Status ConvolutionThunk::Convolve(
}

std::vector<AlgorithmDesc> ConvolutionThunk::GetAlgorithms(
    se::StreamExecutor* stream_exec) const {
    bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const {
  std::vector<AlgorithmDesc> algorithms;
  // TODO(yangzihao): Currently disable the use of winograd nonfused in XLA
  // by default. Should send in conv parameters and enable it when
  // ShouldIncludeWinogradNonfusedAlgo() returns true.
  switch (convolution_kind_) {
    case ConvolutionKind::kBackwardFilter:
      CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
          /*with_winograd_nonfused=*/false, &algorithms));
          with_winograd_nonfused, &algorithms));
      break;
    case ConvolutionKind::kBackwardInput:
      CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
          /*with_winograd_nonfused=*/false, &algorithms));
          with_winograd_nonfused, &algorithms));
      break;
    case ConvolutionKind::kForward:
      CHECK(stream_exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/false,
      CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused,
                                               &algorithms));
      break;
  }
@ -287,6 +284,26 @@ static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) {
  return tensorflow::strings::StrCat(algo.algo_id());
}

// Determines whether we can safely perform a winograd non-fused convolution for
// the given input and output descriptors. This works around b/68264959, an
// integer overflow in cuDNNv5 and cuDNNv6.
static bool ShouldIncludeWinogradNonfusedAlgo(
    const BatchDescriptor& input_descriptor,
    const BatchDescriptor& output_descriptor) {
  int64 batch = input_descriptor.count();
  int64 in_depths = input_descriptor.feature_map_count();
  int64 in_rows = input_descriptor.height();
  int64 in_cols = input_descriptor.width();
  int64 out_depths = output_descriptor.feature_map_count();

  int64 total_size = 16 * std::ceil(batch / 16.0) *
                     std::max(in_depths, out_depths) * in_cols * in_rows *
                     sizeof(float);
  int64 threshold = 1L << 31;

  return total_size < threshold;
}
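
Plugging hypothetical shapes into the formula shows how the 2^31-byte guard behaves (the shapes below are illustrative, not from this change):

// batch=32, depths=256, 512x512 input:
//   16 * ceil(32/16.0) * 256 * 512 * 512 * sizeof(float)
//     = 32 * 256 * 512 * 512 * 4 = 8,589,934,592 >= 2^31  -> winograd excluded
// batch=8, depths=64, 56x56 input:
//   16 * 1 * 64 * 56 * 56 * 4 = 12,845,056 < 2^31         -> winograd included
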

tensorflow::Status ConvolutionThunk::ConvolveWithTune(
    const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data,
    const FilterDescriptor& filter_descriptor,
@ -303,9 +320,13 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
          "ConvolutionThunk: "
       << this;

  bool with_winograd_nonfused =
      ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor);

  se::dnn::ProfileResult best_result;
  se::dnn::ProfileResult best_result_without_scratch;
  std::vector<AlgorithmDesc> algorithms = GetAlgorithms(stream->parent());
  std::vector<AlgorithmDesc> algorithms =
      GetAlgorithms(with_winograd_nonfused, stream->parent());
  for (auto algorithm : algorithms) {
    ConvolveScratchAllocator scratch_allocator(
        buffer_allocations.device_ordinal(),

@ -116,6 +116,7 @@ class ConvolutionThunk : public Thunk {

  // Returns the convolve algorithms that can be used for this ConvolutionThunk.
  std::vector<perftools::gputools::dnn::AlgorithmDesc> GetAlgorithms(
      bool with_winograd_nonfused,
      perftools::gputools::StreamExecutor* stream_exec) const;

  // Fastest cuDNN convolution algorithm for this thunk learned from

@ -1,71 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h"

#include <memory>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
  TF_ASSIGN_OR_RETURN(bool changed, CopyInsertion::Run(module));

  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));

  // Make sure all operands of a library call are in memory instead of constants
  // in IR. The top-level (index {}) of the points-to set of each operand
  // indicates the source(s) of the array buffer. If any of these are constant,
  // then add a copy to materialize the array.
  HloComputation* computation = module->entry_computation();
  for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
    if (ImplementedAsLibraryCall(*hlo)) {
      for (int64 i = 0; i < hlo->operand_count(); ++i) {
        HloInstruction* operand = hlo->mutable_operand(i);
        const PointsToSet& points_to =
            points_to_analysis->GetPointsToSet(operand);
        const auto& element = points_to.element(/*index=*/{});
        if (std::any_of(element.begin(), element.end(),
                        [](const LogicalBuffer* buffer_source) {
                          return buffer_source->instruction()->opcode() ==
                                 HloOpcode::kConstant;
                        })) {
          TF_ASSIGN_OR_RETURN(HloInstruction * copy,
                              CopyInsertion::FindOrInsertCopy(operand));
          TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy));
          changed = true;
        }
      }
    }
  }

  return changed;
}

}  // namespace gpu
}  // namespace xla
@ -33,8 +33,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h"
#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
@ -126,7 +126,7 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) {

// Runs optimization passes on the given HLO module.
tensorflow::Status OptimizeHloModule(
    HloModule* hlo_module, const se::DeviceDescription& device_desc,
    HloModule* hlo_module,
    const HloCostAnalysis::ShapeSizeFunction& shape_size_function) {
  {
    HloPassPipeline pipeline("optimization");
@ -224,9 +224,8 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
  // (and sometime after) copy insertion, to avoid dead code from interfering
  // with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<GpuCopyInsertion>();
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<FlattenCallGraph>();
  pipeline.AddPass<GpuCopyInsertion>();
  return pipeline.Run(hlo_module).status();
}

@ -297,19 +296,23 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
GpuCompiler::GpuCompiler()
    : pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {}

StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/) {
  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
  Tracing::TraceMe annotation("HLO Transforms", module->name(),
                              /*is_expensive=*/true);
  TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), ShapeSizeBytesFunction()));
  return std::move(module);
}

StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");

  TF_RET_CHECK(stream_exec != nullptr);

  {
    Tracing::TraceMe annotation("HLO Transforms", module->name(),
                                /*is_expensive=*/true);
    TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(),
                                         stream_exec->GetDeviceDescription(),
                                         ShapeSizeBytesFunction()));
    TF_RETURN_IF_ERROR(
        PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
  }
  TF_RETURN_IF_ERROR(
      PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));

  llvm::LLVMContext llvm_context;
  std::string buffer;
@ -362,8 +365,11 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
  HloComputation* entry_computation = module->entry_computation();
  IrEmitterUnnested ir_emitter(module->config(), entry_computation,
                               &ir_emitter_context);
  TF_RETURN_IF_ERROR(
      entry_computation->root_instruction()->Accept(&ir_emitter));
  {
    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
    TF_RETURN_IF_ERROR(
        entry_computation->root_instruction()->Accept(&ir_emitter));
  }

  if (user_pre_optimization_hook_) {
    TF_CHECK_OK(user_pre_optimization_hook_(llvm_module));
@ -412,9 +418,12 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
    cc_minor = 0;
  }

  TF_ASSIGN_OR_RETURN(string ptx,
                      CompileToPtx(&llvm_module, {cc_major, cc_minor},
                                   module->config(), libdevice_dir));
  string ptx;
  {
    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - CompileToPtx");
    TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor},
                                          module->config(), libdevice_dir));
  }

  if (!ir_dump_directory.empty()) {
    TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory(
@ -470,6 +479,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
                                                            int cc_major,
                                                            int cc_minor) {
  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult");
  Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true);
  bool inserted;
  decltype(compilation_cache_.begin()) iter;

@ -49,7 +49,11 @@ class GpuCompiler : public LLVMCompiler {
  // stream_execs)
  using LLVMCompiler::Compile;

  StatusOr<std::unique_ptr<Executable>> Compile(
  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module,
      perftools::gputools::StreamExecutor* stream_exec) override;

  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module,
      perftools::gputools::StreamExecutor* stream_exec) override;

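
With Compile() split into RunHloPasses and RunBackend, a caller now drives the two phases explicitly. A hypothetical driver sketch (variable names assumed, not part of this change):

// Phase 1: run HLO-level optimization passes.
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> optimized_module,
                    gpu_compiler->RunHloPasses(std::move(module), stream_exec));
// Phase 2: lower the optimized module to a GPU executable.
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                    gpu_compiler->RunBackend(std::move(optimized_module),
                                             stream_exec));
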
112 tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc Normal file
@ -0,0 +1,112 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"

#include <memory>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace gpu {

StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
    HloInstruction* hlo) {
  HloInstruction*& copy = inserted_copies_[hlo];
  if (copy == nullptr) {
    TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
  }
  return copy;
}

StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
  CopyInsertion generic_copy_insertion;

  TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module));

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                      HloDataflowAnalysis::Run(module));

  // Make sure all operands of a library call are in memory instead of constants
  // in IR.
  for (HloInstruction* hlo :
       module->entry_computation()->MakeInstructionPostOrder()) {
    if (ImplementedAsLibraryCall(*hlo)) {
      for (int64 i = 0; i < hlo->operand_count(); ++i) {
        HloInstruction* operand = hlo->mutable_operand(i);
        TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
        const auto& values = dataflow->GetValueSet(operand).values();
        if (std::any_of(values.begin(), values.end(),
                        [](const HloValue* value) {
                          return value->defining_instruction()->opcode() ==
                                 HloOpcode::kConstant;
                        })) {
          TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand));
          TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy));
          changed = true;
        }
      }
    }
  }

  // Init values of a while node cannot be constants. Insert copies for any
  // constants found at the operand of a while.
  tensorflow::gtl::FlatSet<HloInstruction*> copied_constants;
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() != HloOpcode::kWhile) {
        continue;
      }
      for (auto& pair :
           dataflow->GetInstructionValueSet(instruction->operand(0))) {
        const HloValueSet& value_set = pair.second;
        for (const HloValue* value : value_set.values()) {
          if (value->defining_instruction()->opcode() ==
                  HloOpcode::kConstant &&
              !ContainsKey(copied_constants, value->defining_instruction())) {
            HloInstruction* constant = value->defining_instruction();
            TF_ASSIGN_OR_RETURN(HloInstruction * copy,
                                FindOrInsertCopy(constant));
            TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
            copied_constants.insert(constant);
            changed = true;
          }
        }
      }
    }
  }

  // The GPU backend needs additional copies added due to deficiencies in
  // buffer assignment.
  TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed,
                      CopyInsertion::AddCopiesForBufferAssignment(module));

  return changed || buffer_assignment_changed;
}

}  // namespace gpu
}  // namespace xla
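
The while-loop branch above exists because a while's init buffer is updated in place by the loop body, so it must not alias a constant. Schematically, using the same builder API that appears in the tests later in this change (shapes and names illustrative; the pass itself copies via DeepCopyInstruction):

auto constant = builder.AddInstruction(
    HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
// After the pass, the while consumes a kCopy of the constant instead of the
// constant itself, so the loop body can write the buffer safely.
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
    constant->shape(), HloOpcode::kCopy, constant));
auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
    constant->shape(), condition, body, copy));
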
@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_

#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {
namespace gpu {
@ -25,12 +25,23 @@ namespace gpu {
// Besides the modifications made by the generic xla::CopyInsertion, this
// GPU-specific copy insertion also materializes operands of library calls by
// inserting kCopy instructions.
class GpuCopyInsertion : public CopyInsertion {
class GpuCopyInsertion : public HloPassInterface {
 public:
  tensorflow::StringPiece name() const override { return "copy-insertion"; }

  StatusOr<bool> Run(HloModule* module) override;

 protected:
  // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
  // duplicate copies.
  StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);

  // A map containing all copies inserted to materialize operands of library
  // calls. The key is the copied instruction and the value is the copy.
  tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_
@ -492,9 +492,8 @@ StatusOr<string> CompileToPtx(llvm::Module* module,
  tensorflow::port::Tracing::TraceMe annotation(
      "Compiling IR", llvm_ir::AsString(module->getName()),
      /*is_expensive=*/true);
  ScopedLoggingTimer compilation_timer(
      "Compile module " + llvm_ir::AsString(module->getName()),
      /*vlog_level=*/2);
  XLA_SCOPED_LOGGING_TIMER("Compile module " +
                           llvm_ir::AsString(module->getName()));
  TF_ASSIGN_OR_RETURN(
      ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config,
                              libdevice_dir_path));

@ -17,9 +17,12 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {
@ -33,8 +36,6 @@ class WhileTransformerTest : public HloTestBase {
      : module_(CreateNewModule()),
        induction_variable_shape_(ShapeUtil::MakeShape(S32, {})),
        data_shape_(ShapeUtil::MakeShape(F32, {8})),
        loop_state_shape_(ShapeUtil::MakeTupleShape(
            {induction_variable_shape_, data_shape_})),
        condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {}

  std::unique_ptr<HloComputation> BuildConditionComputation(
@ -42,8 +43,8 @@ class WhileTransformerTest : public HloTestBase {
    auto builder = HloComputation::Builder(TestName() + ".Condition");
    auto limit_const = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<int32>(limit)));
    auto loop_state = builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
    auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
        0, GetLoopStateShape(tuple_index), "loop_state"));
    auto induction_variable =
        builder.AddInstruction(HloInstruction::CreateGetTupleElement(
            limit_const->shape(), loop_state, tuple_index));
@ -58,8 +59,8 @@ class WhileTransformerTest : public HloTestBase {
                                                  const int64 increment) {
    auto builder = HloComputation::Builder(TestName() + ".Body");
    // Create param instruction to access loop state.
    auto loop_state = builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
    auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
        0, GetLoopStateShape(ind_var_tuple_index), "loop_state"));
    // Update the induction variable GTE(ind_var_tuple_index).
    auto induction_variable =
        builder.AddInstruction(HloInstruction::CreateGetTupleElement(
@ -73,7 +74,7 @@ class WhileTransformerTest : public HloTestBase {
            data_shape_, loop_state, data_tuple_index));
    // Use 'induction_variable' in computation with no path to output tuple.
    auto update = builder.AddInstruction(
        HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
        HloInstruction::CreateBroadcast(data_shape_, induction_variable, {}));
    auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
        data_shape_, HloOpcode::kAdd, data, update));
    // Create output Tuple.
@ -98,8 +99,9 @@ class WhileTransformerTest : public HloTestBase {
                  HloInstruction::CreateTuple({induction_var_init, data_init}))
            : builder.AddInstruction(
                  HloInstruction::CreateTuple({data_init, induction_var_init}));
    auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
        loop_state_shape_, condition, body, loop_state_init));
    auto while_hlo = builder.AddInstruction(
        HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index),
                                    condition, body, loop_state_init));
    module_->AddEntryComputation(builder.Build());
    return while_hlo;
  }
@ -115,18 +117,34 @@ class WhileTransformerTest : public HloTestBase {
  }

  void RunCopyInsertionPass() {
    HloVerifier verifier([](const Shape& shape) {
      return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
    });
    TF_ASSERT_OK(verifier.Run(module_.get()).status());
    CopyInsertion copy_insertion;
    EXPECT_IS_OK(copy_insertion.Run(module_.get()).status());
    TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
  }

  Shape GetLoopStateShape(const int64 ind_var_tuple_index) {
    if (ind_var_tuple_index == 0) {
      return ShapeUtil::MakeTupleShape(
          {induction_variable_shape_, data_shape_});
    } else {
      return ShapeUtil::MakeTupleShape(
          {data_shape_, induction_variable_shape_});
    }
  }

  std::unique_ptr<HloModule> module_;
  Shape induction_variable_shape_;
  Shape data_shape_;
  Shape loop_state_shape_;
  Shape condition_result_shape_;
};

TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
// TODO(b/68830972): The while transformer is far too fragile. It
// pattern-matches the exact expressions of opcodes. Re-enable when the
// transformation is more general.
TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) {
  // Build computation with induction variable at tuple element 0.
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
@ -137,13 +155,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  ASSERT_TRUE(result.ok());
  TF_ASSERT_OK(result.status());
  // Check results.
  EXPECT_THAT(result.ConsumeValueOrDie(),
              Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}

TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
// TODO(b/68830972): The while transformer is far too fragile. It
// pattern-matches the exact expressions of opcodes. Re-enable when the
// transformation is more general.
TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) {
  // Build computation with induction variable at tuple element 1.
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(1, 10));
@ -154,13 +175,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  ASSERT_TRUE(result.ok());
  TF_ASSERT_OK(result.status());
  // Check results.
  EXPECT_THAT(result.ConsumeValueOrDie(),
              Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}

TEST_F(WhileTransformerTest, InvalidLoopLimit) {
// TODO(b/68830972): The while transformer is far too fragile. It
// pattern-matches the exact expressions of opcodes. Re-enable when the
// transformation is more general.
TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) {
  // Build computation with invalid loop limit.
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 5));
@ -176,7 +200,10 @@ TEST_F(WhileTransformerTest, InvalidLoopLimit) {
              HasSubstr("Loop start must be less than loop limit."));
}

TEST_F(WhileTransformerTest, InvalidLoopIncrement) {
// TODO(b/68830972): The while transformer is far too fragile. It
// pattern-matches the exact expressions of opcodes. Re-enable when the
// transformation is more general.
TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) {
  // Build computation with invalid loop increment.
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));

@ -250,7 +250,3 @@ message HloProto {
  HloOrderingProto hlo_ordering = 2;
  BufferAssignmentProto buffer_assignment = 3;
}

message HloProtos {
  repeated HloProto hlo_protos = 1;
}

@ -144,8 +144,10 @@ class BufferValueMap {
  // Move the given value into the given buffer.
  void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
    BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
    buffers_.at(old_buffer_number).erase(&value);
    if (buffers_.at(old_buffer_number).empty()) {
    tensorflow::gtl::FlatSet<const HloValue*>& old_value_set =
        buffers_.at(old_buffer_number);
    old_value_set.erase(&value);
    if (old_value_set.empty()) {
      buffers_.erase(old_buffer_number);
    }

@ -175,7 +177,7 @@ class BufferValueMap {
    // Value is init of a while (use is while).
    std::vector<BufferNumber> aliased_buffers;
    for (const HloUse& use : value.uses()) {
      VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
      VLOG(2) << "use of value " << value.ToShortString() << ": " << use;
      if (use.instruction->opcode() == HloOpcode::kWhile) {
        // Determine the while value that this shares a buffer with.
        const HloValue& while_value =
@ -411,7 +413,7 @@ string HloAliasAnalysis::ToString() const {
/* static */
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
    HloModule* module) {
  VLOG(1) << "HloAliasAnalysis::Run on module " << module->name();
  VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
  XLA_VLOG_LINES(2, module->ToString());

  auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
@ -444,7 +446,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(

  TF_DCHECK_OK(alias_analysis->Verify());

  XLA_VLOG_LINES(1, alias_analysis->ToString());
  XLA_VLOG_LINES(2, alias_analysis->ToString());
  return std::move(alias_analysis);
}

@ -407,16 +407,18 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
    HloModule* module, const HloComputationProto& proto,
    tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map,
    const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
    const std::function<void(std::unique_ptr<HloComputation>)>&
        add_fused_computation,
    HloInstruction* fusion_instruction) {
  std::vector<std::unique_ptr<HloInstruction>> instructions;
  tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
  int64 parameter_count = 0;
  for (const HloInstructionProto& instruction_proto : proto.instructions()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloInstruction> instruction,
        HloInstruction::CreateFromProto(module, instruction_proto,
                                        instruction_map, computation_map));
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
                        HloInstruction::CreateFromProto(
                            module, instruction_proto, instruction_map,
                            computation_map, add_fused_computation));
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }

@ -152,12 +152,16 @@ class HloComputation {
  // computation_map: a map from computation name to HloComputation*. This map
  //   must contain all computations which the newly constructed computation
  //   calls.
  // fusion_instruction: if non-null then the newly created computation will be
  //   constructed as a fused computation with this instruction as its fusion
  //   parent.
  // add_fused_computation: A function to call to add a fused
  //   computation. Used only when the instruction is a fusion instruction.
  // fusion_instruction: if non-null then the newly created computation will
  //   be constructed as a fused computation with this instruction as its
  //   fusion parent.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      HloModule* module, const HloComputationProto& proto,
      tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map,
      const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
      const std::function<void(std::unique_ptr<HloComputation>)>&
          add_fused_computation,
      HloInstruction* fusion_instruction = nullptr);

  // Gets the instructions in this computation.

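
A sketch of the new calling convention: the computation map is now read-only, and ownership of fused computations is handed back through the callback (the surrounding variables are assumed for illustration):

std::vector<std::unique_ptr<HloComputation>> fused_computations;
TF_ASSIGN_OR_RETURN(
    std::unique_ptr<HloComputation> computation,
    HloComputation::CreateFromProto(
        module, proto, computation_map,
        /*add_fused_computation=*/
        [&](std::unique_ptr<HloComputation> fused) {
          fused_computations.push_back(std::move(fused));
        }));
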
@ -22,13 +22,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace xla {

constexpr char HloCostAnalysis::kFlopsKey[];
constexpr char HloCostAnalysis::kTranscendentalsKey[];
constexpr char HloCostAnalysis::kBytesAccessedKey[];
constexpr char HloCostAnalysis::kSecondsKey[];
constexpr char HloCostAnalysis::kOptimalSecondsKey[];

HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
    : HloCostAnalysis(shape_size, {}) {}
@ -60,16 +61,16 @@ Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
  if (current_should_compute_bottleneck_time_) {
    // Compute the time as the time of the bottleneck, i.e. the slowest property
    // given the per-second rate of each property.
    float max_seconds = 0.0f;
    float optimal_seconds = 0.0f;
    for (const auto& property : current_properties_) {
      if (property.first != kSecondsKey) {
        max_seconds = std::max(
            max_seconds,
      if (property.first != kOptimalSecondsKey) {
        optimal_seconds = std::max(
            optimal_seconds,
            property.second /
                GetProperty(property.first, per_second_rates_, INFINITY));
      }
    }
    current_properties_[kSecondsKey] = max_seconds;
    current_properties_[kOptimalSecondsKey] = optimal_seconds;
  }

  TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
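
A worked example of the bottleneck rule with assumed per-second rates: an instruction costing 100 flops at 50 flops/s and 32 bytes accessed at 32 bytes/s yields

  optimal_seconds = max(100 / 50, 32 / 32) = max(2.0, 1.0) = 2.0

so the slowest resource dominates the time estimate.
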
@ -480,6 +481,25 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
  return Status::OK();
}

Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
  // Compute the cost of the true and false computations and take the maximum
  // from those for each property.
  TF_ASSIGN_OR_RETURN(const Properties true_computation_properties,
                      ProcessSubcomputation(conditional->true_computation()));
  TF_ASSIGN_OR_RETURN(const Properties false_computation_properties,
                      ProcessSubcomputation(conditional->false_computation()));
  current_properties_ = true_computation_properties;
  for (const auto& property : false_computation_properties) {
    if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_, property)) {
      current_properties_[property.first] =
          std::max(current_properties_[property.first], property.second);
    }
  }
  current_should_compute_bottleneck_time_ = false;

  return Status::OK();
}

Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
  return Status::OK();
}
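
Worked example of the per-property max-combine: if the true branch costs {flops: 10} and the false branch costs {flops: 4, transcendentals: 2}, the conditional is charged {flops: max(10, 4) = 10, transcendentals: 2} -- a conservative bound, since only one branch actually executes.
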
@ -496,8 +516,8 @@ float HloCostAnalysis::bytes_accessed() const {
  return GetProperty(kBytesAccessedKey, properties_sum_);
}

float HloCostAnalysis::seconds() const {
  return GetProperty(kSecondsKey, properties_sum_);
float HloCostAnalysis::optimal_seconds() const {
  return GetProperty(kOptimalSecondsKey, properties_sum_);
}

int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
@ -512,8 +532,8 @@ int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
}

float HloCostAnalysis::seconds(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kSecondsKey, hlo_properties_);
float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
}

StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(

@ -42,7 +42,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
  static constexpr char kFlopsKey[] = "flops";
  static constexpr char kTranscendentalsKey[] = "transcendentals";
  static constexpr char kBytesAccessedKey[] = "bytes accessed";
  static constexpr char kSecondsKey[] = "seconds";
  static constexpr char kOptimalSecondsKey[] = "optimal_seconds";

  // shape_size is a function which returns the size in bytes of the top-level
  // buffer of a shape.
@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
  Status HandleReshape(const HloInstruction* reshape) override;
  Status HandleTranspose(const HloInstruction* transpose) override;
  Status HandleWhile(const HloInstruction* xla_while) override;
  Status HandleConditional(const HloInstruction* conditional) override;
  Status FinishVisit(const HloInstruction* root) override;

  Status Preprocess(const HloInstruction* hlo) override;
@ -118,14 +119,14 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
  float flop_count() const;
  float transcendental_count() const;
  float bytes_accessed() const;
  float seconds() const;
  float optimal_seconds() const;

  // Returns the respective cost computed for a particular HLO instruction, or 0
  // if the HLO was not found to have a cost in the analysis.
  int64 flop_count(const HloInstruction& hlo) const;
  int64 transcendental_count(const HloInstruction& hlo) const;
  int64 bytes_accessed(const HloInstruction& hlo) const;
  float seconds(const HloInstruction& hlo) const;
  float optimal_seconds(const HloInstruction& hlo) const;

  const Properties& properties() const { return properties_sum_; }
  const float property(const string& key) const {

@ -389,7 +389,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
    static_assert(bytes_accessed == 64, "");
    EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);

    EXPECT_EQ(fusion_analysis.seconds(), 1 << i);
    EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i);
  }
}

@ -37,6 +37,9 @@ namespace xla {
StatusOr<bool> HloDCE::Run(HloModule* module) {
  bool changed = false;

  VLOG(2) << "Before dce:";
  XLA_VLOG_LINES(2, module->ToString());

  for (auto* computation : module->MakeNonfusionComputations()) {
    std::unordered_set<HloInstruction*> live_instructions;
    TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
@ -58,6 +61,8 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
    }

    for (HloInstruction* dead_root : dead_roots) {
      VLOG(1) << "Removing dead root " << dead_root->ToString()
              << " and its unused operands";
      TF_RETURN_IF_ERROR(
          computation->RemoveInstructionAndUnusedOperands(dead_root));
      changed = true;
@ -87,6 +92,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
    }
  }

  VLOG(2) << "After dce:";
  XLA_VLOG_LINES(2, module->ToString());

  return changed;
}

@ -335,9 +335,31 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
                          return ~elem_operand;
                        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
                          return !elem_operand;
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
@ -357,7 +379,24 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return HandleNot<ReturnT>(not_);
  }

  Status HandleNegate(HloInstruction* negate) override {
  template <typename NativeT,
            typename std::enable_if<
                std::is_signed<NativeT>::value &&
                !std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleNegate(HloInstruction* negate) {
    using type = typename std::make_unsigned<NativeT>::type;
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate],
                        ElementWiseUnaryOp(negate, [](ReturnT elem_operand) {
                          return NativeT(-type(elem_operand));
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                !std::is_signed<NativeT>::value ||
                std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleNegate(HloInstruction* negate) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate],
                        ElementWiseUnaryOp(negate, [](ReturnT elem_operand) {
                          return -elem_operand;
@ -365,6 +404,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return Status::OK();
  }

  Status HandleNegate(HloInstruction* negate) override {
    return HandleNegate<ReturnT>(negate);
  }

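
The detour through std::make_unsigned in the signed-integer handler avoids undefined behavior: negating the minimum value of a signed type overflows, while unsigned arithmetic wraps. A standalone illustration (assumes the usual two's-complement conversion back to signed):

#include <cstdint>

// -INT32_MIN overflows int32_t, which is UB; the unsigned round-trip wraps
// and converts back to the expected two's-complement bit pattern.
int32_t SafeNegate(int32_t v) {
  return static_cast<int32_t>(-static_cast<uint32_t>(v));
}
// SafeNegate(INT32_MIN) == INT32_MIN. The multiply handler below applies the
// same unsigned trick to keep lhs * rhs free of signed-overflow UB.
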
  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
@ -402,7 +445,26 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return Status::OK();
  }

  Status HandleMultiply(HloInstruction* multiply) override {
  template <typename NativeT,
            typename std::enable_if<
                std::is_signed<NativeT>::value &&
                !std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleMultiply(HloInstruction* multiply) {
    using type = typename std::make_unsigned<NativeT>::type;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[multiply],
        ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) {
          return NativeT(type(lhs_elem) * type(rhs_elem));
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_unsigned<NativeT>::value ||
                              std::is_floating_point<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMultiply(HloInstruction* multiply) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[multiply],
        ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) {
@ -411,6 +473,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return Status::OK();
  }

  Status HandleMultiply(HloInstruction* multiply) override {
    return HandleMultiply<ReturnT>(multiply);
  }

  Status HandleSubtract(HloInstruction* subtract) override {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[subtract],
@ -516,9 +582,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return HandleRemainder<ReturnT>(remainder);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleAnd(HloInstruction* and_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[and_],
        ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) {
          return lhs_el & rhs_el;
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAnd(HloInstruction* and_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[and_],
@ -539,9 +616,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return HandleAnd<ReturnT>(and_);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleOr(HloInstruction* or_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[or_],
        ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) {
          return lhs_el | rhs_el;
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleOr(HloInstruction* or_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[or_],
@ -645,7 +733,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleClamp(HloInstruction* clamp) {
    std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
        [](ReturnT low, ReturnT high, ReturnT value) {
        [](ReturnT low, ReturnT value, ReturnT high) {
          return std::fmax(low, std::fmin(value, high));
        };
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
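
With the reordered parameters, the lambda's names now match kClamp's (min, operand, max) operand order; the computed value is unchanged, since std::fmin is commutative. A worked evaluation: clamp(low=2, value=5, high=4) = std::fmax(2, std::fmin(5, 4)) = std::fmax(2, 4) = 4.
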
@ -814,13 +902,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
            }

            rhs_index[dnums.kernel_spatial_dimensions(ki)] =
                rhs_spatial_index[ki];
                window_dim.window_reversal()
                    ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
                    : rhs_spatial_index[ki];
          }

          result_val += lhs_literal.Get<ReturnT>(lhs_index) *
                        rhs_literal.Get<ReturnT>(rhs_index);
        }
      cnt:;
      cnt : {}
      } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));

      return result_val;
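
Window reversal traverses the kernel back to front: for a window of size 3, k=0 maps to (3-1)-0 = 2, k=1 to 1, and k=2 to 0. The `cnt : {}` change is cosmetic C++: a label must be followed by a statement, and an empty block serves where the bare `;` did.
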
@ -1287,6 +1377,50 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleSin(HloInstruction* sin) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin],
                        ElementWiseUnaryOp(sin, [](ReturnT elem_operand) {
                          return std::sin(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_integral<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleSin(HloInstruction* sin) {
    return InvalidArgument("Unsupported type for Sin");
  }

  Status HandleSin(HloInstruction* sin) override {
    return HandleSin<ReturnT>(sin);
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleCos(HloInstruction* cos) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos],
                        ElementWiseUnaryOp(cos, [](ReturnT elem_operand) {
                          return std::cos(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_integral<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleCos(HloInstruction* cos) {
    return InvalidArgument("Unsupported type for Cos");
  }

  Status HandleCos(HloInstruction* cos) override {
    return HandleCos<ReturnT>(cos);
  }

 private:
  template <typename IndexT>
  StatusOr<std::unique_ptr<Literal>> DynamicSlice(
@ -1397,8 +1531,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
  const auto* rhs = instruction->operand(1);
  const auto* ehs = instruction->operand(2);

  // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
  // removed.
  // TODO(b/35950897, b/27796129): add DCHECK back once implicit
  // broadcast is removed.
  if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) &&
        ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) &&
        ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) {
@ -1565,6 +1699,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
  }

  std::vector<HloInstruction*> operands;
  operands.reserve(owned_operands.size());
  for (auto& operand : owned_operands) {
    operands.push_back(operand.get());
  }

||||
@ -46,20 +46,57 @@ class HloEvaluatorTest : public HloVerifiedTestBase {
|
||||
HloEvaluatorTest() { evaluator_ = MakeUnique<HloEvaluator>(); }
|
||||
|
||||
std::unique_ptr<HloEvaluator> evaluator_;
|
||||
|
||||
void TestUnaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
|
||||
std::unique_ptr<Literal> input, float aabs = 0) {
|
||||
HloComputation::Builder b(TestName());
|
||||
auto c1 =
|
||||
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
|
||||
auto instruction = b.AddInstruction(
|
||||
HloInstruction::CreateUnary(expected->shape(), opcode, c1));
|
||||
module().AddEntryComputation(b.Build());
|
||||
|
||||
std::unique_ptr<Literal> result =
|
||||
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
|
||||
|
||||
auto element_type = expected->shape().element_type();
|
||||
if (element_type == F32 || element_type == F64) {
|
||||
ErrorSpec error(aabs);
|
||||
LiteralTestUtil::ExpectNear(*expected, *result, error);
|
||||
} else {
|
||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
||||
}
|
||||
}
|
||||
|
||||
void TestBinaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
|
||||
std::unique_ptr<Literal> lhs,
|
||||
std::unique_ptr<Literal> rhs) {
|
||||
HloComputation::Builder b(TestName());
|
||||
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
|
||||
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
|
||||
auto instruction = b.AddInstruction(
|
||||
HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2));
|
||||
module().AddEntryComputation(b.Build());
|
||||
|
||||
std::unique_ptr<Literal> result =
|
||||
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
|
||||
|
||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
||||
}
|
||||
};
|
||||
|
||||
// Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
|
||||
// with 3 operands.
|
||||
TEST_F(HloEvaluatorTest, DoesClamp) {
|
||||
auto low = Literal::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
|
||||
auto high = Literal::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
|
||||
auto value = Literal::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
|
||||
auto high = Literal::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
|
||||
|
||||
Shape shape = low->shape();
|
||||
HloComputation::Builder b(TestName());
|
||||
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
|
||||
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
|
||||
auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
|
||||
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
|
||||
auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
|
||||
auto instruction = b.AddInstruction(
|
||||
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
|
||||
module().AddEntryComputation(b.Build());
|
||||
@ -72,6 +109,28 @@ TEST_F(HloEvaluatorTest, DoesClamp) {
|
||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
||||
}
|
||||
|
||||
TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
|
||||
auto low = Literal::CreateR0<float>(0.f);
|
||||
auto value = Literal::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
|
||||
auto high = Literal::CreateR0<float>(1.f);
|
||||
|
||||
Shape shape = value->shape();
|
||||
HloComputation::Builder b(TestName());
|
||||
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
|
||||
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
|
||||
auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
|
||||
auto instruction = b.AddInstruction(
|
||||
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
|
||||
module().AddEntryComputation(b.Build());
|
||||
|
||||
std::unique_ptr<Literal> result =
|
||||
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
|
||||
|
||||
auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
|
||||
|
||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
||||
}
|
||||
|
||||
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
|
||||
// with 3 operands.
|
||||
TEST_F(HloEvaluatorTest, DoesSelect) {
|
||||
@ -103,120 +162,101 @@ TEST_F(HloEvaluatorTest, DoesSelect) {
|
||||
TEST_F(HloEvaluatorTest, DoesAdd) {
|
||||
auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
|
||||
auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
|
||||
|
||||
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
|
||||
HloComputation::Builder b(TestName());
|
||||
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
|
||||
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
|
||||
auto instruction = b.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1, c2));
|
||||
module().AddEntryComputation(b.Build());
|
||||
|
||||
std::unique_ptr<Literal> result =
|
||||
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();
|
||||
|
||||
auto expected = Literal::CreateR2<int64>({{3, 4}, {-96, 8}});
|
||||
|
||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
||||
TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
|
||||
std::move(rhs));
|
||||
}

// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise and with 2 operands.
TEST_F(HloEvaluatorTest, DoesAnd) {
auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
auto expected = Literal::CreateR2<int64>({{0, 0}, {4, 4}});
TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise or with 2 operands.
TEST_F(HloEvaluatorTest, DoesOr) {
auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
auto expected = Literal::CreateR2<int64>({{3, 4}, {-100, 4}});
TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise multiply with 2 operands.
TEST_F(HloEvaluatorTest, DoesMultiply) {
auto lhs = Literal::CreateR2<int32>({{-1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int32>(
{{std::numeric_limits<int32>::min(), 4}, {4, 4}});
auto expected = Literal::CreateR2<int32>(
{{std::numeric_limits<int32>::min(), 0}, {-400, 16}});
TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
std::move(rhs));
}

// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise divide with 2 operands.
TEST_F(HloEvaluatorTest, DoesDivideInt64) {
auto lhs_s64 = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs_s64 = Literal::CreateR2<int64>({{2, 4}, {4, 4}});

Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2});
HloComputation::Builder b(TestName());
auto c1_s64 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_s64)));
auto c2_s64 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_s64)));
auto instruction = b.AddInstruction(HloInstruction::CreateBinary(
shape_s64, HloOpcode::kDivide, c1_s64, c2_s64));
module().AddEntryComputation(b.Build());

std::unique_ptr<Literal> result =
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();

auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
auto expected = Literal::CreateR2<int64>({{0, 0}, {-25, 1}});

LiteralTestUtil::ExpectEqual(*expected, *result);
TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
std::move(rhs));
}
TEST_F(HloEvaluatorTest, DoesDivideDouble) {
auto lhs_f64 = Literal::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
auto rhs_f64 = Literal::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});

Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2});
HloComputation::Builder b(TestName());
auto c1_f64 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_f64)));
auto c2_f64 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_f64)));
auto instruction = b.AddInstruction(HloInstruction::CreateBinary(
shape_f64, HloOpcode::kDivide, c1_f64, c2_f64));
module().AddEntryComputation(b.Build());

auto result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();

auto lhs = Literal::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
auto rhs = Literal::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
auto expected =
Literal::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});

LiteralTestUtil::ExpectEqual(*expected, *result);
TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
std::move(rhs));
}

// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest, DoesAbsR2) {
auto operand = Literal::CreateR2<int64>({{1, -20}, {-100, 4}});
const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2});
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
auto instruction =
b.AddInstruction(HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1));
module().AddEntryComputation(b.Build());

std::unique_ptr<Literal> result =
evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie();

auto expected = Literal::CreateR2<int64>({{1, 20}, {100, 4}});

LiteralTestUtil::ExpectEqual(*expected, *result);
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesAbsR0) {
// For R0 literal.
const Shape& r0 = ShapeUtil::MakeShape(F32, {});
auto operand = Literal::CreateR0<float>(-1.0f);
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
auto instruction =
b.AddInstruction(HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1));
module().AddEntryComputation(b.Build());

auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie();
auto expected = Literal::CreateR0<float>(1.0f);

LiteralTestUtil::ExpectEqual(*expected, *result);
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesAbsR1WithZeroSize) {
// For R1 literal with dimension of size 0.
Shape empty_r1 = ShapeUtil::MakeShape(F32, {0});
auto operand = Literal::CreateR1<float>({});
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
auto instruction = b.AddInstruction(
HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1));
module().AddEntryComputation(b.Build());

auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie();
auto expected = Literal::CreateR1<float>({});

LiteralTestUtil::ExpectEqual(*expected, *result);
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesNegateR2) {
auto operand = Literal::CreateR2<int32>(
{{0, std::numeric_limits<int32>::min()}, {-1, 4}});
auto expected =
Literal::CreateR2<int32>({{0, std::numeric_limits<int>::min()}, {1, -4}});
TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesCosR2) {
auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
auto expected = Literal::CreateR2<float>({{1, -1}, {-1, 1}});
TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand));
}
TEST_F(HloEvaluatorTest, DoesSinR2) {
auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
auto expected = Literal::CreateR2<float>({{0, 0}, {0, 0}});
TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand),
0x1.0P-20);
}
TEST_F(HloEvaluatorTest, DoesNotR2) {
auto operand =
Literal::CreateR2<int32>({{0, std::numeric_limits<int>::min()},
{-1, std::numeric_limits<int>::max()}});
auto expected =
Literal::CreateR2<int32>({{-1, std::numeric_limits<int>::max()},
{0, std::numeric_limits<int>::min()}});
TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand));
}

// Verifies that HloEvaluator evaluates a HLO Computation whose operands are
// neither parameters nor constants.
TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
@ -794,6 +834,83 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
LiteralTestUtil::ExpectEqual(*expected, *result);
}

TEST_F(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
HloComputation::Builder b(TestName());

// clang-format off
// Input dimensions: [feature=2, height=3, batch=1, width=4]
Array4D<float> input({
{{{1, 2, 3, 4}},
{{5, 6, 7, 8}},
{{9, 10, 11, 12}}},
{{{13, 14, 15, 16}},
{{17, 18, 19, 20}},
{{21, 22, 23, 24}}}
});
// Weight dimensions:
// [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
Array4D<float> weight({{
{{1, 7, 13},
{4, 10, 16}},
{{2, 8, 14},
{5, 11, 17}},
{{3, 9, 15},
{6, 12, 18}}
}});
// clang-format on

auto lhs_literal = Literal::CreateR4FromArray4D<float>(input);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));

auto rhs_literal = Literal::CreateR4FromArray4D<float>(weight);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse(
rhs_instruction->shape(), rhs_instruction, {3, 1}));

Window window;
WindowDimension dim;
dim.set_size(3);
dim.set_stride(1);
dim.set_padding_low(0);
dim.set_padding_high(0);
dim.set_window_dilation(1);
dim.set_base_dilation(1);
dim.set_window_reversal(true);
*window.add_dimensions() = dim;
*window.add_dimensions() = dim;

ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(2);
dnums.set_output_batch_dimension(2);
dnums.set_input_feature_dimension(0);
dnums.set_output_feature_dimension(0);
dnums.add_spatial_dimensions(1);
dnums.add_spatial_dimensions(3);

dnums.set_kernel_output_feature_dimension(0);
dnums.set_kernel_input_feature_dimension(2);
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(1);

const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
shape, lhs_instruction, rhs_instruction, window, dnums));
auto computation = module().AddEntryComputation(b.Build());

std::unique_ptr<Literal> result =
evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie();

// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
Array4D<float> expected_array({{{{2514, 2685}}}});
// clang-format on
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);

LiteralTestUtil::ExpectEqual(*expected, *result);
}
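
One way to read dim.set_window_reversal(true): the kernel window is traversed back-to-front along that spatial dimension, so pre-reversing the weights with kReverse, as this test does, appears designed to cancel out and reproduce the plain convolution result. A 1-D illustration (not part of the commit):

#include <cstddef>
#include <vector>

// Dot product of one kernel window against the input, starting at `start`.
// With window_reversal, the kernel is read in reverse index order.
float WindowDot(const std::vector<float>& input,
                const std::vector<float>& kernel, std::size_t start,
                bool window_reversal) {
  float acc = 0.0f;
  const std::size_t size = kernel.size();
  for (std::size_t k = 0; k < size; ++k) {
    const std::size_t kernel_index = window_reversal ? size - 1 - k : k;
    acc += input[start + k] * kernel[kernel_index];
  }
  return acc;
}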

TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) {
HloComputation::Builder b(TestName());

@ -83,7 +83,7 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter(
instruction_info->transcendental_count =
cost_analysis.transcendental_count(*hlo);
instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo);
instruction_info->seconds = cost_analysis.seconds(*hlo);
instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo);
instruction_info->profile_index =
hlo_profile_index_map.GetProfileIndexFor(*hlo);
CHECK_LT(instruction_info->profile_index, max_profile_index);

@ -90,10 +90,10 @@ TEST_F(HloExecutionProfileTest, Basic) {
const std::vector<string>& line_3 = lines_and_words[3];

EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles));
EXPECT_EQ(line_2[kInstructionNameIndex], dot_instruction->name());
EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name());

EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles));
EXPECT_EQ(line_3[kInstructionNameIndex], add_instruction->name());
EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name());
}
} // namespace
} // namespace xla
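
The expectations above track a convention change that runs through the rest of this commit: instruction names are now stored without the leading '%', and printers prepend the sigil at output time. Roughly (illustrative only; the real formatting lives in HloInstruction::ToString further below):

// Names like "add.1" are stored bare; rendering adds the '%' sigil, so the
// profile printer and the dumpers all show "%add.1".
string PrintedName(const HloInstruction& instruction) {
  return "%" + instruction.name();
}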

@ -864,9 +864,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
// (eg, parameter).
switch (instr->opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kAdd:
case HloOpcode::kAnd:
case HloOpcode::kAtan2:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kComplex:
@ -882,18 +883,19 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kRng:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
@ -903,7 +905,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kSort:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kRng:
// De-emphasize scalar-shaped elementwise ops -- they're generally
// uninteresting.
if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
@ -911,9 +912,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
}
return kYellow;
case HloOpcode::kBitcast:
case HloOpcode::kTuple:
case HloOpcode::kTrace:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTrace:
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
// De-emphasize nodes which broadcast a scalar within a fusion node --
@ -952,28 +953,28 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kRed;
case HloOpcode::kParameter:
return kParameterColor;
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kReduce:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
return kPurple;
case HloOpcode::kMap:
case HloOpcode::kFusion:
case HloOpcode::kMap:
return kGray;
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
return kBrown:
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
case HloOpcode::kWhile:
case HloOpcode::kCall:
return kDarkGreen;
case HloOpcode::kConstant:
LOG(FATAL) << "Constants don't get their own nodes in the graph.";

@ -52,7 +52,9 @@ using ::tensorflow::strings::StrCat;
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map) {
const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
const std::function<void(std::unique_ptr<HloComputation>)>&
add_fused_computation) {
TF_RET_CHECK(!proto.opcode().empty());
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
@ -78,19 +80,19 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(!proto.fusion_kind().empty());
TF_ASSIGN_OR_RETURN(instruction->fusion_kind_,
StringToFusionKind(proto.fusion_kind()));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> fused_computation,
HloComputation::CreateFromProto(
module, proto.fused_instructions_computation(), computation_map,
/*fusion_instruction=*/instruction.get()));
instruction->called_computations_.push_back(
module->AddEmbeddedComputation(std::move(fused_computation)));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> fused_computation,
HloComputation::CreateFromProto(
module, proto.fused_instructions_computation(),
computation_map, add_fused_computation,
/*fusion_instruction=*/instruction.get()));
instruction->called_computations_.push_back(fused_computation.get());
add_fused_computation(std::move(fused_computation));
} else {
for (const string& computation_name : proto.called_computation_names()) {
TF_RET_CHECK(ContainsKey(*computation_map, computation_name))
TF_RET_CHECK(ContainsKey(computation_map, computation_name))
<< "No computation named " << computation_name;
instruction->called_computations_.push_back(
computation_map->at(computation_name));
computation_map.at(computation_name));
}
}

@ -149,7 +151,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
WrapUnique(new HloInstruction(HloOpcode::kParameter, shape));
instruction->parameter_number_ = parameter_number;
instruction->parameter_name_ = name;
instruction->name_ = "%" + name;
instruction->name_ = name;
return instruction;
}

@ -436,6 +438,23 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
const Shape& shape, HloInstruction* pred,
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation) {
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
instruction->AppendOperand(pred);
instruction->AppendOperand(true_computation_arg);
instruction->AppendOperand(false_computation_arg);
// In called_computations_, the index of true_computation must be 0 and that
// of false computation must be 1, as defined by kTrueComputationIndex and
// kFalseComputationIndex.
instruction->called_computations_.push_back(true_computation);
instruction->called_computations_.push_back(false_computation);
return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
@ -499,6 +518,15 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
return instruction;
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBitcastConvert(const Shape& shape,
HloInstruction* operand) {
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
@ -631,7 +659,10 @@ HloInstruction::CreateSelectAndScatter(
CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
CHECK(std::equal(operand->shape().dimensions().begin(),
operand->shape().dimensions().end(),
Permute(dimensions, shape.dimensions()).begin()));
Permute(dimensions, shape.dimensions()).begin()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< ", operand->shape(): " << ShapeUtil::HumanString(operand->shape())
<< ", dimensions: {" << Join(dimensions, ", ") << "}";
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape));
instruction->AppendOperand(operand);
@ -791,7 +822,7 @@ HloInstruction* HloInstruction::FuseInstructionInternal(
HloInstruction* HloInstruction::CloneAndFuseInternal(
HloInstruction* instruction_to_fuse, bool add_output) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(instruction_to_fuse->IsFusable());
CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString();
VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
HloInstruction* clone = nullptr;
if (called_computations_.empty()) {
@ -869,10 +900,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// parameter instruction.
int64 param_no = fused_parameters.size();
// Name the parameter after the instruction it represents in the outer
// (non-fusion) computation. Strip the leading "%" from the operand name
// to avoid a double %%.
string param_name =
StrCat(operand->name().substr(1), ".param_", param_no);
// (non-fusion) computation.
string param_name = StrCat(operand->name(), ".param_", param_no);
fused_param = fused_instructions_computation()->AddParameter(
CreateParameter(param_no, operand->shape(), param_name));
AppendOperand(operand);
@ -1013,7 +1042,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
for (const HloInstruction* new_operand : new_operands) {
VLOG(3) << " " << new_operand->name();
VLOG(3) << " %" << new_operand->name();
}

std::unique_ptr<HloInstruction> clone;
@ -1095,6 +1124,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateConvert(shape, new_operands[0]);
break;
case HloOpcode::kBitcastConvert:
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
case HloOpcode::kReducePrecision:
CHECK_EQ(new_operands.size(), 1);
clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_,
@ -1535,6 +1568,7 @@ bool HloInstruction::IdenticalSlowPath(
// A convert result is determined by the primitive type that the operand is
// converted into.
case HloOpcode::kConvert:
case HloOpcode::kBitcastConvert:
return shape().element_type() == other.shape().element_type();

// A reduce-precision operation is determined by the bit sizes.
@ -1814,6 +1848,32 @@ void HloInstruction::set_scatter(HloComputation* computation) {
called_computations_[kScatterComputationIndex] = computation;
}

HloComputation* HloInstruction::true_computation() const {
CHECK_EQ(HloOpcode::kConditional, opcode_);
return called_computations_[kTrueComputationIndex];
}

HloComputation* HloInstruction::false_computation() const {
CHECK_EQ(HloOpcode::kConditional, opcode_);
return called_computations_[kFalseComputationIndex];
}

void HloInstruction::set_true_computation(HloComputation* true_computation) {
// Don't allow changing the computation for fused instructions so we don't
// have to recompute called_instructions for the entire fusion instruction.
CHECK(!IsFused());
CHECK_EQ(HloOpcode::kConditional, opcode_);
called_computations_[kTrueComputationIndex] = true_computation;
}

void HloInstruction::set_false_computation(HloComputation* false_computation) {
// Don't allow changing the computation for fused instructions so we don't
// have to recompute called_instructions for the entire fusion instruction.
CHECK(!IsFused());
CHECK_EQ(HloOpcode::kConditional, opcode_);
called_computations_[kFalseComputationIndex] = false_computation;
}
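
A sketch of how a caller might wire up the new conditional support (a hypothetical helper; the factory and the accessors are the ones added by this commit):

std::unique_ptr<HloInstruction> MakeConditional(
    const Shape& result_shape, HloInstruction* pred,
    HloInstruction* true_arg, HloComputation* true_computation,
    HloInstruction* false_arg, HloComputation* false_computation) {
  auto conditional = HloInstruction::CreateConditional(
      result_shape, pred, true_arg, true_computation, false_arg,
      false_computation);
  // The branches land at kTrueComputationIndex (0) and
  // kFalseComputationIndex (1), which is exactly what true_computation()
  // and false_computation() read back.
  return conditional;
}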

string HloInstruction::SignatureString() const {
string operands =
Join(operands_, ", ", [](string* out, HloInstruction* operand) {
@ -1825,7 +1885,7 @@ string HloInstruction::SignatureString() const {
string HloInstruction::ToString(bool compact_operands, bool include_metadata,
bool include_large_constants) const {
string result =
StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
StrCat("%", name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
HloOpcodeString(opcode()), "(",
OperandsToString(compact_operands, include_large_constants), ")");
for (const string& extra : ExtraAttributesToString()) {
@ -1877,7 +1937,7 @@ string HloInstruction::OperandsToString(bool compact,
operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
*out += ShapeUtil::HumanStringWithLayout(operand->shape());
if (!compact) {
StrAppend(out, " ", operand->name());
StrAppend(out, " %", operand->name());
}
});
const int64 remaining = operands_.size() - slice.size();
@ -1896,7 +1956,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
if (CanHaveDimensionsField()) {
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
}
if (window_ != nullptr) {
if (window_ != nullptr && window_->dimensions_size() != 0) {
extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
}
if (padding_config_ != nullptr) {
@ -1964,7 +2024,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
extra.push_back(StrCat("control-predecessors={",
Join(control_predecessors_, ", ",
[](string* out, HloInstruction* pre) {
StrAppend(out, pre->name());
StrAppend(out, "%", pre->name());
}),
"}"));
}
@ -1979,10 +2039,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
}

string HloInstruction::ToShortString() const {
return StrCat(name(), " = ", HloOpcodeString(opcode()), "(",
return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
Join(operands_, ", ",
[](string* out, HloInstruction* operand) {
StrAppend(out, operand->name());
StrAppend(out, "%", operand->name());
}),
")");
}
@ -2076,8 +2136,10 @@ string HloInstruction::ToCategory() const {
bool saw_rank_1 = false;
bool saw_higher_rank = false;
for (const auto* operand : operands()) {
saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1;
saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1;
if (!ShapeUtil::IsTuple(operand->shape())) {
saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1;
saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1;
}
}
if (saw_rank_1 && saw_higher_rank) {
return "rank-1-broadcast binary fusion";
@ -2130,25 +2192,13 @@ bool HloInstruction::IsFusable() const {
if (tracing()) {
return false;
}

// Some kinds of instructions don't make sense to fuse.
switch (opcode_) {
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kTrace:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
return false;
// Only fuse Rng if it is used once, otherwise the random numbers generated
// will be different in each fusion. If it is the root (user count = 0)
// then it is the equivalent of having one user.
case HloOpcode::kRng:
return users_.size() <= 1;
// Side effecting instructions cannot be fused.
default:
return true;
return !HasSideEffect();
}
}

@ -2199,7 +2249,7 @@ HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
: unique_id_(-1),
opcode_(opcode),
shape_(shape),
name_("%" + HloOpcodeString(opcode)) {
name_(HloOpcodeString(opcode)) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}

@ -2259,6 +2309,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleConcatenate(this);
case HloOpcode::kConvert:
return visitor->HandleConvert(this);
case HloOpcode::kBitcastConvert:
return visitor->HandleBitcastConvert(this);
case HloOpcode::kCopy:
return visitor->HandleCopy(this);
case HloOpcode::kMultiply:
@ -2345,6 +2397,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleFusion(this);
case HloOpcode::kCall:
return visitor->HandleCall(this);
case HloOpcode::kConditional:
return visitor->HandleConditional(this);
case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this);
case HloOpcode::kRecv:
@ -2357,7 +2411,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSendDone(this);

// These opcodes are not handled here.
case HloOpcode::kConditional:
case HloOpcode::kTrace:
break;
}
@ -2423,7 +2476,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
visitor->GetVisitState(current_id);
if (visit_state == Visitor::kVisited) {
dfs_stack.pop_back();
VLOG(3) << "Not visiting HLO " << current_node->name()
VLOG(3) << "Not visiting HLO %" << current_node->name()
<< " as it was already visited.";
continue;
}
@ -2432,7 +2485,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
dfs_stack.pop_back();

TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
VLOG(2) << "Visiting HLO " << current_node->name();
VLOG(2) << "Visiting HLO %" << current_node->name();
TF_RETURN_IF_ERROR(current_node->Visit(visitor));
visitor->SetVisitState(current_id, Visitor::kVisited);
TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
@ -2477,7 +2530,7 @@ template <typename HloInstructionPtr>
Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
bool call_finish_visit,
bool ignore_control_predecessors) {
VLOG(3) << "HloInstruction::Accept(" << name() << ")";
VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
TF_RETURN_IF_ERROR(
PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
if (call_finish_visit) {
@ -2493,7 +2546,7 @@ template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
Status HloInstruction::AcceptWithOperandOrder(
DfsHloVisitor* visitor, const CompareFunction& operand_order,
bool call_finish_visit) {
VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")";
VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
InternalCompareFunction func = [&operand_order](
std::pair<int, const HloInstruction*> a,
std::pair<int, const HloInstruction*> b) {
@ -2556,7 +2609,7 @@ Status HloInstruction::Accept(

Status HloInstruction::AcceptOrdered(
DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order) {
VLOG(2) << "HloInstruction::AcceptOrdered(" << name() << ")";
VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")";
TF_RET_CHECK(OrderIsTopologicalSort(order));

// Compute the predecessors of this instruction.
@ -2575,7 +2628,7 @@ Status HloInstruction::AcceptOrdered(
// The visitor can mark instructions as visited to skip particular
// instructions.
if (visitor->DidVisit(*const_instruction)) {
VLOG(3) << "Not visiting HLO " << const_instruction->name()
VLOG(3) << "Not visiting HLO %" << const_instruction->name()
<< " as it was already visited.";
continue;
}
@ -2584,7 +2637,7 @@ Status HloInstruction::AcceptOrdered(
const_cast<HloInstruction*>(const_instruction);

TF_RETURN_IF_ERROR(visitor->Preprocess(instruction));
VLOG(2) << "Visiting HLO " << instruction->name();
VLOG(2) << "Visiting HLO %" << instruction->name();
TF_RETURN_IF_ERROR(instruction->Visit(visitor));
visitor->SetVisited(*instruction);
TF_RETURN_IF_ERROR(visitor->Postprocess(instruction));
@ -2630,6 +2683,7 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kConvert:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:

@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
@ -83,12 +84,16 @@ class HloInstruction {
// must contain all operands of the newly constructed instruction.
// computation_map: a map from computation name to HloComputation*. This map
// must contain all computations which the newly constructed instruction
// calls. If the instruction is a fusion instruction, then the fusion
// computation is added to this map and the module.
// calls.
// add_fused_computation: A function to call to add a fused
// computation. Used when the instruction is a fusion instruction.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map);
const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
const std::function<void(std::unique_ptr<HloComputation>)>&
add_fused_computation);

// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
@ -171,6 +176,11 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
HloInstruction* operand);

// Creates a bitcast conversion instruction, where operand is the data to
// convert and shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateBitcastConvert(
const Shape& shape, HloInstruction* operand);

// Creates an infeed instruction, which reads data of the given shape from the
// Infeed interface of the device.
static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
@ -305,6 +315,11 @@ class HloInstruction {
HloComputation* body,
HloInstruction* init);

static std::unique_ptr<HloInstruction> CreateConditional(
const Shape& shape, HloInstruction* pred,
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation);

// Creates a fusion instruction. A fusion instruction contains one or more
// fused instructions forming an expression with a single root
// "fused_root". Additional instructions can be added to the fusion
@ -608,6 +623,15 @@ class HloInstruction {
void set_select(HloComputation* select);
void set_scatter(HloComputation* scatter);

// Gets/sets the true and false HloComputation for Conditional. The setters
// should only be called by HloModule or HloComputation methods.
//
// Precondition: The instruction is a Conditional instruction.
HloComputation* true_computation() const;
HloComputation* false_computation() const;
void set_true_computation(HloComputation* true_computation);
void set_false_computation(HloComputation* false_computation);

// Returns a string for the signature of this instruction if considered as a
// function, e.g. the signature of an F32 add is (F32, F32) -> F32.
string SignatureString() const;
@ -1192,6 +1216,10 @@ class HloInstruction {
// kSelectAndScatter computations.
kSelectComputationIndex = 0,
kScatterComputationIndex = 1,

// kConditional computations.
kTrueComputationIndex = 0,
kFalseComputationIndex = 1,
};

// Outfeed configuration information, only present for kOutfeed.

@ -1138,35 +1138,34 @@ TEST_F(HloInstructionTest, CloneSuffixNames) {
// Test cloning the same instruction multiple times.
auto foo =
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo");
EXPECT_EQ(foo->Clone()->name(), "%foo.clone");
EXPECT_EQ(foo->Clone()->Clone()->name(), "%foo.clone2");
EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "%foo.clone3");
EXPECT_EQ(foo->Clone()->name(), "foo.clone");
EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2");
EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3");

// Test custom suffixes.
EXPECT_EQ(foo->Clone("bar")->name(), "%foo.bar");
EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "%foo.bar2");
EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(),
"%foo.bar2.clone");
EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar");
EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2");
EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone");

// Test instruction name with a dot.
auto foo_baz = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "foo.baz");
EXPECT_EQ(foo_baz->Clone()->name(), "%foo.baz.clone");
EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone");

// Test incrementing a large number after the suffix.
auto foo_clone234 = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "foo.clone234");
EXPECT_EQ(foo_clone234->Clone()->name(), "%foo.clone235");
EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235");

// Test a non-numeric string after the cloning suffix.
auto foo_clonexyz = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz");
EXPECT_EQ(foo_clonexyz->Clone()->name(), "%foo.clonexyz.clone");
EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone");

// Test a name with multiple appearances of the suffix.
auto foo_clone_clone3 = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3");
EXPECT_EQ(foo_clone_clone3->Clone()->name(), "%foo.clone.clone4");
EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4");
}

TEST_F(HloInstructionTest, Stringification) {

@ -87,6 +87,7 @@ HLO_MATCHER(Call);
HLO_MATCHER(Ceil);
HLO_MATCHER(Clamp);
HLO_MATCHER(Concatenate);
HLO_MATCHER(Conditional);
HLO_MATCHER(Constant);
HLO_MATCHER(Convert);
HLO_MATCHER(Convolution);

@ -290,9 +290,16 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(

tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
for (const HloComputationProto& computation_proto : proto.computations()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
HloComputation::CreateFromProto(
module.get(), computation_proto, &computation_map));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> computation,
HloComputation::CreateFromProto(
module.get(), computation_proto, computation_map,
/*add_fused_computation=*/
[&module](std::unique_ptr<HloComputation> fused_computation) {
module->AddComputationInternal(std::move(fused_computation),
/*is_entry=*/false,
/*uniquify_names=*/false);
}));
CHECK_NE(computation.get(), nullptr);
TF_RET_CHECK(!ContainsKey(computation_map, computation->name()));
string computation_name = computation->name();

@ -52,6 +52,7 @@ namespace xla {
V(kBatchNormInference, "batch-norm-inference") \
V(kBatchNormTraining, "batch-norm-training") \
V(kBitcast, "bitcast") \
V(kBitcastConvert, "bitcast-convert") \
V(kBroadcast, "broadcast") \
V(kCall, "call", kHloOpcodeIsVariadic) \
V(kCeil, "ceil") \

@ -50,7 +50,7 @@ string HloProfilePrinter::ToString(const int64* counters,
/*short_name=*/instruction->short_name, instruction->category,
counters[instruction->profile_index], instruction->flop_count,
instruction->transcendental_count, instruction->bytes_accessed,
instruction->seconds);
instruction->optimal_seconds);
}

result += builder.ToString();

@ -41,7 +41,7 @@ class HloProfilePrinter {
float flop_count;
float transcendental_count;
float bytes_accessed;
float seconds;
float optimal_seconds;

// The index into the profile counters array for the HloInstruction
// corresponding to this HloInstructionInfo.

@ -62,18 +62,11 @@ bool IsRematerializable(const HloInstruction* instruction) {
case HloOpcode::kConstant:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kCustomCall:
case HloOpcode::kOutfeed:
case HloOpcode::kInfeed:
case HloOpcode::kParameter:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTrace:
case HloOpcode::kWhile:
return false;
default:
return true;
return !instruction->HasSideEffect();
}
}

@ -114,11 +114,16 @@ HloRunner::~HloRunner() {
StatusOr<se::DeviceMemoryBase> HloRunner::Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
Shape* result_shape) {
Shape* result_shape, bool run_hlo_passes) {
if (run_hlo_passes) {
TF_ASSIGN_OR_RETURN(
module, backend().compiler()->RunHloPasses(
std::move(module), backend().default_stream_executor()));
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend().compiler()->Compile(std::move(module),
backend().default_stream_executor()));
backend().compiler()->RunBackend(std::move(module),
backend().default_stream_executor()));

se::Stream stream(backend().default_stream_executor());
stream.Init();
@ -193,10 +198,12 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferFromDevice(

StatusOr<std::unique_ptr<Literal>> HloRunner::ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
bool run_hlo_passes) {
Shape result_shape;
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase device_base,
Execute(std::move(module), arguments, &result_shape));
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase device_base,
Execute(std::move(module), arguments, &result_shape, run_hlo_passes));
return TransferFromDevice(result_shape, device_base);
}
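
A hypothetical call site for the new run_hlo_passes flag (not part of the commit): passing false makes HloRunner skip RunHloPasses and hand the module straight to RunBackend, so the executed HLO is exactly what the test built.

StatusOr<std::unique_ptr<Literal>> RunUnoptimized(
    HloRunner* runner, std::unique_ptr<HloModule> module,
    tensorflow::gtl::ArraySlice<Literal*> literals) {
  // run_hlo_passes=false: compile the module as-is, without HLO optimization.
  return runner->Execute(std::move(module), literals,
                         /*run_hlo_passes=*/false);
}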
|
||||
|
||||
|
||||
@ -65,17 +65,20 @@ class HloRunner {
|
||||
// Executes the given module with given literals as input and returns the
|
||||
// result as a Literal. The LiteralPtr type accepts Literal* or
|
||||
// std::unique_ptr<Literal>.
|
||||
// If run_hlo_passes is true, the module will be executed without Hlo
|
||||
// optimization.
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<std::unique_ptr<Literal>> Execute(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const tensorflow::gtl::ArraySlice<LiteralPtr> literals);
|
||||
const tensorflow::gtl::ArraySlice<LiteralPtr> literals,
|
||||
bool run_hlo_passes = true);
|
||||
|
||||
// Executes the given module and returns a global data handle.
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
|
||||
std::unique_ptr<HloModule> module,
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
arguments,
|
||||
Shape* result_shape);
|
||||
Shape* result_shape, bool run_hlo_passes = true);
|
||||
|
||||
// Transfers the given literal to the device and returns the data handle.
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase> TransferToDevice(
|
||||
@ -90,7 +93,8 @@ class HloRunner {
|
||||
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
|
||||
std::unique_ptr<HloModule> module,
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
arguments);
|
||||
arguments,
|
||||
bool run_hlo_passes = true);
|
||||
|
||||
// If backend is not created in the constructor, creates and returns the
|
||||
// default backend. If creation fails, crashes the program.
|
||||
@ -112,14 +116,15 @@ class HloRunner {
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const tensorflow::gtl::ArraySlice<LiteralPtr> literals) {
|
||||
const tensorflow::gtl::ArraySlice<LiteralPtr> literals,
|
||||
bool run_hlo_passes) {
|
||||
std::vector<perftools::gputools::DeviceMemoryBase> arguments;
|
||||
for (const auto& literal : literals) {
|
||||
TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument,
|
||||
TransferToDevice(*literal));
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return ExecuteAndTransfer(std::move(module), arguments);
|
||||
return ExecuteAndTransfer(std::move(module), arguments, run_hlo_passes);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
@ -160,7 +160,59 @@ bool HloSharding::HasUniqueDevice() const {
|
||||
}
|
||||
}
|
||||
|
||||
Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
|
||||
if (!ShapeUtil::IsTuple(shape)) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
StrCat("Sharding is tuple-shaped but validation shape is not."));
|
||||
}
|
||||
// The easiest way to get the number of elements in a nested tuple is just to
|
||||
// create a shape tree. We could call GetAsShapeTree, but that will try and
|
||||
// apply our tuple_shardings_ to the shape tree, and that might cause a crash
|
||||
// at this point as we haven't validated them.
|
||||
ShapeTree<bool> bool_shape_tree(shape, false);
|
||||
int64 num_leaves =
|
||||
std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end());
|
||||
if (num_leaves != tuple_elements_.size()) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
StrCat("Validation tuple shape has ", num_leaves,
|
||||
" leaf elements, but this sharding contains ",
|
||||
tuple_elements_.size(), " elements."));
|
||||
}
|
||||
|
||||
// Now we've validated the number of tuple elements, it's safe to request a
|
||||
// shape tree.
|
||||
ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
|
||||
for (const auto& index_to_sharding : shape_tree.leaves()) {
|
||||
Status status = index_to_sharding.second.ValidateNonTuple(
|
||||
ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
|
||||
if (!status.ok()) {
|
||||
tensorflow::errors::AppendToMessage(
|
||||
&status, StrCat("Note: While validating sharding tuple element ",
|
||||
index_to_sharding.first.ToString(), " which is ",
|
||||
index_to_sharding.second.ToString()));
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
|
||||
Status status = IsTuple() ? ValidateTuple(shape, num_devices)
|
||||
: ValidateNonTuple(shape, num_devices);
|
||||
if (!status.ok()) {
|
||||
tensorflow::errors::AppendToMessage(
|
||||
&status, StrCat("Note: While validating sharding ", ToString(),
|
||||
" against shape ", ShapeUtil::HumanString(shape)));
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
Status HloSharding::ValidateNonTuple(const Shape& shape,
|
||||
int64 num_devices) const {
|
||||
if (ShapeUtil::IsTuple(shape)) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
StrCat("Validation shape is a tuple but sharding is not."));
|
||||
}
|
||||
if (replicated_) {
|
||||
return Status::OK();
|
||||
}
|
||||
@ -174,13 +226,11 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
|
||||
// Don't overwrite a bad status, so we report the first error.
|
||||
if (status.ok()) {
|
||||
if (core >= num_devices) {
|
||||
status =
|
||||
tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
|
||||
"core ", core, " > ", num_devices, " in tile assignment"));
|
||||
status = tensorflow::errors::InvalidArgument(StrCat(
|
||||
"core ", core, " > ", num_devices, " in tile assignment"));
|
||||
} else if (seen_cores.count(core) != 0) {
|
||||
status =
|
||||
tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
|
||||
"core ", core, " is not unique in tile assignment"));
|
||||
status = tensorflow::errors::InvalidArgument(
|
||||
StrCat("core ", core, " is not unique in tile assignment"));
|
||||
}
|
||||
}
|
||||
seen_cores.insert(core);
|
||||
@ -214,9 +264,9 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
|
||||
auto tile_dim = tile_shape_.dimensions(i);
|
||||
auto shape_dim = shape.dimensions(i);
|
||||
if (tile_dim > shape_dim) {
|
||||
return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
|
||||
"Tile is larger than input shape (dimension ", i, ", ", tile_dim,
|
||||
" > ", shape_dim));
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
StrCat("Tile is larger than input shape (dimension ", i, ", ",
|
||||
tile_dim, " > ", shape_dim));
|
||||
}
|
||||
}
|
||||
|
||||
@ -226,10 +276,10 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
|
||||
int64 expected_dim =
|
||||
CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i));
|
||||
if (tile_assignment_.dimensions()[i] != expected_dim) {
|
||||
return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
|
||||
"Tile assignment tensor has incorrect shape. Dimension ", i,
|
||||
" expected ", expected_dim, " but got ",
|
||||
tile_assignment_.dimensions()[i]));
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
StrCat("Tile assignment tensor has incorrect shape. Dimension ", i,
|
||||
" expected ", expected_dim, " but got ",
|
||||
tile_assignment_.dimensions()[i]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -222,6 +222,11 @@ class HloSharding {
|
||||
tile_assignment_({0}),
|
||||
tuple_elements_(tuple_shardings) {}
|
||||
|
||||
// Internal helper to validate a tuple sharding.
|
||||
Status ValidateTuple(const Shape& shape, int64 num_devices) const;
|
||||
// Internal helper to validate a non-tuple (leaf) sharding.
|
||||
Status ValidateNonTuple(const Shape& shape, int64 num_devices) const;
|
||||
|
||||
bool replicated_;
|
||||
bool maximal_;
|
||||
bool tuple_;
|
||||
|
||||
@ -145,11 +145,13 @@ TEST_F(HloShardingTest, NestedTuple) {
|
||||
ShapeUtil::MakeShape(F32, {4, 6}),
|
||||
});
|
||||
|
||||
HloSharding tiled_sharding = HloSharding::Tile(
|
||||
ShapeUtil::MakeShape(F32, {4, 3}), Array<int64>({{0, 1}}));
|
||||
OpSharding proto;
|
||||
proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
|
||||
*proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
|
||||
*proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto();
|
||||
*proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto();
|
||||
*proto.add_tuple_shardings() = tiled_sharding.ToProto();
|
||||
HloSharding tuple_sharding =
|
||||
HloSharding::FromProto(proto).ConsumeValueOrDie();
|
||||
|
||||
@ -157,7 +159,15 @@ TEST_F(HloShardingTest, NestedTuple) {
|
||||
tuple_sharding.GetAsShapeTree(nested_tuple_shape);
|
||||
EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
|
||||
EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
|
||||
EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1));
|
||||
EXPECT_EQ(shape_tree.element({2}), tiled_sharding);
|
||||
|
||||
EXPECT_IS_OK(tuple_sharding.Validate(nested_tuple_shape, /*num_devices=*/5));
|
||||
// Test should fail because tuple element count does not match.
|
||||
EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeTupleShape({}),
|
||||
/*num_devices=*/5));
|
||||
// Test should fail because the input type is not a tuple.
|
||||
EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeShape(F32, {}),
|
||||
/*num_devices=*/5));
|
||||
}
|
||||
|
||||
TEST_F(HloShardingTest, Hash) {
|
||||
|
||||
@ -59,15 +59,17 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
}
|
||||
|
||||
Status HandleConvert(HloInstruction* convert) override {
|
||||
if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) {
|
||||
TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape()))
|
||||
<< "Unsupported complex->real kConvert";
|
||||
}
|
||||
return CheckShape(convert, ShapeInference::InferConvertShape(
|
||||
convert->operand(0)->shape(),
|
||||
convert->shape().element_type()));
|
||||
}
|
||||
|
||||
Status HandleBitcastConvert(HloInstruction* convert) override {
|
||||
return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
|
||||
convert->operand(0)->shape(),
|
||||
convert->shape().element_type()));
|
||||
}
|
||||
|
||||
Status HandleCopy(HloInstruction* copy) override {
|
||||
return CheckUnaryShape(copy);
|
||||
}
|
||||
@ -263,6 +265,15 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
xla_while->while_body()->ComputeProgramShape().result());
|
||||
}
|
||||
|
||||
Status HandleConditional(HloInstruction* conditional) override {
|
||||
TF_RETURN_IF_ERROR(CheckShape(
|
||||
conditional,
|
||||
conditional->true_computation()->ComputeProgramShape().result()));
|
||||
return CheckShape(
|
||||
conditional,
|
||||
conditional->false_computation()->ComputeProgramShape().result());
|
||||
}
|
||||
|
||||
Status HandlePad(HloInstruction* pad) override {
|
||||
return CheckShape(pad,
|
||||
ShapeInference::InferPadShape(pad->operand(0)->shape(),
|
||||
@ -571,7 +582,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
|
||||
// or ComputationLowerer::Visit()
|
||||
TF_RET_CHECK(instruction->dimensions().size() ==
|
||||
ShapeUtil::Rank(instruction->operand(0)->shape()))
|
||||
<< "Broadcast HLO has invalid number of dimensions.";
|
||||
<< "Broadcast HLO has invalid number of dimensions.";
|
||||
} else if (instruction->opcode() == HloOpcode::kWhile) {
|
||||
auto* while_cond = instruction->while_condition();
|
||||
auto* while_body = instruction->while_body();
|
||||
|
||||
@ -33,7 +33,9 @@ namespace xla {
switch (instruction.opcode()) {
// Cheap instructions.
case HloOpcode::kAdd:
case HloOpcode::kAnd:
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kBroadcast:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
@ -53,15 +55,14 @@ namespace xla {
case HloOpcode::kInfeed:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kOutfeed:
case HloOpcode::kPad:
case HloOpcode::kReal:
@ -88,9 +89,9 @@ namespace xla {

// Expensive instructions.
case HloOpcode::kAtan2:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConvolution:
@ -104,19 +105,19 @@ namespace xla {
case HloOpcode::kMap:
case HloOpcode::kParameter:
case HloOpcode::kPower:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kRemainder:
case HloOpcode::kRng:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kSort:
case HloOpcode::kTanh:
case HloOpcode::kTrace:
case HloOpcode::kWhile:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
return true;
}

@ -69,11 +69,19 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
return pipeline.Run(hlo_module).status();
}

StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::Compile(
StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
std::unique_ptr<HloModule> hlo_module,
se::StreamExecutor* /*stream_exec*/) {
VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
return std::move(hlo_module);
}

StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec) {
TF_RET_CHECK(stream_exec != nullptr);

VLOG(1) << "Generate graph " << hlo_module->name();
VLOG(1) << "Run backend " << hlo_module->name();

TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));

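The hunks above and below record the split of the monolithic Compile entry point into a two-phase API. A minimal sketch of the new calling convention (the calls are lifted from the Service::BuildExecutable hunk later in this diff, so the names are the commit's own; `backend', `module', and `executor' stand in for whatever the caller has in scope):

  // Phase 1: run the HLO optimization pipeline over the module.
  TF_ASSIGN_OR_RETURN(
      module, backend->compiler()->RunHloPasses(std::move(module), executor));
  // Phase 2: hand the optimized module to the backend for code generation.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      backend->compiler()->RunBackend(std::move(module), executor));
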
@ -43,8 +43,12 @@ class InterpreterCompiler : public Compiler {
InterpreterCompiler() {}
~InterpreterCompiler() override {}

StatusOr<std::unique_ptr<Executable>> Compile(
std::unique_ptr<HloModule> hlo_modules,
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> hlo_module,
perftools::gputools::StreamExecutor* stream_exec) override;

StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> hlo_module,
perftools::gputools::StreamExecutor* stream_exec) override;

StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(

@ -103,7 +103,7 @@ namespace {

// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
// where 'user' is a user of an alias of 'intruction' at 'index', and
// where 'user' is a user of an alias of 'instruction' at 'index', and
// 'operand_index' is the operand index at which the alias appears in the
// operand list of 'user'.
std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
@ -243,6 +243,31 @@ bool CanShareOperandBufferWithUser(
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
if (user->opcode() == HloOpcode::kCall) {
// TODO(b/62548313): Remove when buffer assignment is module scoped and
// does not assign buffers to calls.
// Find called computation parameter associated with 'operand'.
const std::vector<int64> operand_indices = user->OperandIndices(operand);
if (operand_indices.size() > 1) {
return false;
}
CHECK_EQ(1, operand_indices.size());
auto* param = user->to_apply()->parameter_instruction(operand_indices[0]);
// Get all uses of 'operand' at 'index' in called computation.
auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index,
points_to_analysis);

// Return true iff:
// *) There exists exactly one use of 'operand' in called computation.
// *) The unique use is by the root instruction of called computation.
// (Note: we check the root of the called computation, because the
// root result buffer is required to alias with the Call result buffer).
// *) The root instruction of the called computation is element-wise on
// 'operand'.
auto* callee_root = user->to_apply()->root_instruction();
return param_uses.size() == 1 && param_uses[0].first == callee_root &&
callee_root->IsElementwiseOnOperand(param_uses[0].second);
}
// Check if 'user' is element-wise.
return user->IsElementwise();
}
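The aliasing rule in this hunk and the next one hinges on the callee root being element-wise on the parameter. A standalone illustration (plain C++, not XLA code) of why element-wise consumers tolerate buffer reuse:

  #include <vector>

  // out[i] depends only on in[i], so outputs may be written directly over the
  // input buffer without clobbering any value that is still needed.
  void AddOneInPlace(std::vector<float>& buf) {
    for (float& v : buf) v = v + 1.0f;
  }
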
@ -322,6 +347,31 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand,
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
if (user->opcode() == HloOpcode::kCall) {
// Get all uses of value defined by 'operand' at 'operand_index'.
const auto& uses =
dataflow.GetValueDefinedAt(operand, operand_index).uses();
// Return true iff:
// *) There exist two uses of 'operand'.
// *) One use is by 'user' (caller).
// *) One use is by root instruction of called computation (callee root).
// (Note: we check the root of the called computation, because the
// root result buffer is required to alias with the Call result buffer).
// *) The root instruction of the called computation is element-wise on
// 'operand'.
const bool found_caller_use =
std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) {
return use.instruction == user;
}) != uses.end();
auto* callee_root = user->to_apply()->root_instruction();
const bool found_elementwise_callee_use =
std::find_if(
uses.begin(), uses.end(), [callee_root](const HloUse& use) {
return use.instruction == callee_root &&
callee_root->IsElementwiseOnOperand(use.operand_number);
}) != uses.end();
return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
}
// Check if 'user' is element-wise.
return user->IsElementwise();
}
@ -415,5 +415,44 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_));
}

// Tests that Call can alias operand buffer if the only use of the operand
// in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
Shape shape = ShapeUtil::MakeShape(F32, {8});
// Build sub-computation with fusion root.
auto sub_builder = HloComputation::Builder(TestName() + "_sub");
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "sub_param"));
auto one = sub_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto ones = sub_builder.AddInstruction(
HloInstruction::CreateBroadcast(shape, one, {1}));
auto add = sub_builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));

module_ = CreateNewModule();
auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
sub_computation->CreateFusionInstruction({add, ones},
HloInstruction::FusionKind::kLoop);

// Build entry-computation with kCall which calls 'sub_computation'.
auto builder = HloComputation::Builder(TestName());

auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
auto reverse =
builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
auto call = builder.AddInstruction(
HloInstruction::CreateCall(shape, {reverse}, sub_computation));
computation_ = module_->AddEntryComputation(builder.Build());

RunAnalysis();

EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {},
*points_to_analysis_));
EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {},
*dataflow_analysis_));
}

} // namespace
} // namespace xla
@ -27,8 +27,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
"Model partitioning not implemented for the CPU/GPU compilers!");
}

TF_ASSIGN_OR_RETURN(
modules[i], RunHloPasses(std::move(modules[i]), stream_execs[i][0]));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
Compile(std::move(modules[i]), stream_execs[i][0]));
RunBackend(std::move(modules[i]), stream_execs[i][0]));
result.push_back(std::move(executable));
}

@ -58,10 +58,14 @@ class LLVMCompiler : public Compiler {
void RemovePostOptimizationHook() { user_post_optimization_hook_ = nullptr; }

// Bring in
// StatusOr<std::unique_ptr<Executable>> Compile(
// std::unique_ptr<HloModule> module,
// perftools::gputools::StreamExecutor* executor)
using Compiler::Compile;
// StatusOr<std::unique_ptr<Executable>> RunBackend(
// std::unique_ptr<HloModule> module,
// perftools::gputools::StreamExecutor* stream_exec)
// StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
// std::unique_ptr<HloModule> module,
// perftools::gputools::StreamExecutor* stream_exec)
using Compiler::RunBackend;
using Compiler::RunHloPasses;

StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::vector<std::unique_ptr<HloModule>> modules,
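The using-declarations above exist because of C++ name hiding: an overload declared in a derived class hides every base-class overload of the same name. A self-contained sketch of the rule (generic C++, not XLA code):

  struct Base {
    int f(int x) { return x; }
    int f(int x, int y) { return x + y; }
  };
  struct Derived : Base {
    using Base::f;                  // re-exposes Base::f(int, int)
    int f(int x) { return x * 2; }  // without the using, this would hide it
  };
  int use() { Derived d; return d.f(1) + d.f(1, 2); }  // 2 + 3 == 5
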
@ -40,11 +40,24 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
inline bool CanEmitFusedDynamicUpdateSliceInPlace(
HloInstruction* fusion, const BufferAssignment& assignment) {
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop &&
fusion->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice &&
CanUpdateDynamicSliceInPlace(fusion->fused_expression_root(),
assignment);
HloInstruction* fused_root = fusion->fused_expression_root();
if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice ||
fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) {
return false;
}
// Walk DynamicUpdateSlice operand(0) to fused parameter and get its
// associated operand. See if it shares an allocation with this operand.
HloInstruction* fusion_operand;
ShapeIndex index;
std::tie(fusion_operand, index) =
fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
if (fusion_operand->opcode() != HloOpcode::kParameter) {
return false;
}
auto* operand = fusion->operand(fusion_operand->parameter_number());
return assignment.HasAllocationAt(operand, index) &&
assignment.HasAllocationAt(fusion, {}) &&
assignment.SharesSliceAtIndex(fusion, {}, operand, index);
}

// Emits IR for running the given dynamic-update-slice op in-place -- that is,
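For readers unfamiliar with the op, a standalone sketch of what an in-place dynamic-update-slice computes (plain C++; `start' is assumed already clamped so the update fits in the buffer):

  #include <cstddef>
  #include <vector>

  // Writes 'update' into 'buffer' at offset 'start', reusing the operand
  // buffer as the output -- exactly the reuse the check above must prove safe.
  void DynamicUpdateSliceInPlace(std::vector<float>& buffer,
                                 const std::vector<float>& update,
                                 std::size_t start) {
    for (std::size_t i = 0; i < update.size(); ++i) {
      buffer[start + i] = update[i];
    }
  }
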
@ -430,9 +430,12 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
/*include_unreachable_instructions=*/
true));

TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor));

TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend->compiler()->Compile(std::move(module), executor));
backend->compiler()->RunBackend(std::move(module), executor));

if (!other_directory_path.empty()) {
executable->set_session_module(std::move(session_module));
@ -1361,6 +1364,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status =
computation->AddConvertInstruction(arg->convert_request());
break;
case OpRequest::kBitcastConvertRequest:
handle_status = computation->AddBitcastConvertInstruction(
arg->bitcast_convert_request());
break;
case OpRequest::kConvolveRequest:
handle_status =
computation->AddConvolveInstruction(arg->convolve_request());

@ -441,6 +441,14 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,

/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type) {
auto old_element_type = operand_shape.element_type();
if (primitive_util::IsComplexType(old_element_type) &&
!primitive_util::IsComplexType(new_element_type)) {
return Unimplemented(
"Unsupported conversion from complex to real type: %s => %s",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
@ -454,6 +462,36 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
}

/* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type) {
auto old_element_type = operand_shape.element_type();
if (primitive_util::IsComplexType(old_element_type) !=
primitive_util::IsComplexType(new_element_type)) {
return Unimplemented(
"Unsupported conversion between real and complex types: %s => %s",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
return InvalidArgument(
"cannot convert from or to tuple type; requested conversion: %s => %s",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
if (primitive_util::BitWidth(old_element_type) !=
primitive_util::BitWidth(new_element_type)) {
return InvalidArgument(
"cannot bitcast types with different bit-widths: %s => %s",
PrimitiveType_Name(old_element_type).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}

return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
}

/* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
const Shape& operand_shape, const int exponent_bits,
const int mantissa_bits) {
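The equal-bit-width requirement enforced above is what makes a bitcast-convert a pure reinterpretation of each element's bits rather than a value conversion. A self-contained illustration of the element-level semantics (standard C++, independent of XLA); the printed pattern is the same 0x42280000 that the bitcast_convert_test.cc expectations later in this diff rely on:

  #include <cstdint>
  #include <cstdio>
  #include <cstring>

  int main() {
    float f = 42.0f;
    std::uint32_t bits;
    static_assert(sizeof f == sizeof bits, "bitcast needs equal bit-widths");
    std::memcpy(&bits, &f, sizeof bits);  // reinterpret, don't convert
    std::printf("0x%08X\n", static_cast<unsigned>(bits));  // 0x42280000
    return 0;
  }
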
@ -204,6 +204,13 @@ class ShapeInference {
static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
PrimitiveType new_element_type);

// Helper that validates the given operand shape can be bitcast converted to
// the target output_shape via a bitcast convert instruction -- the
// requirement is that the shape is identical except for the element type and
// the element types have identical bit-widths.
static StatusOr<Shape> InferBitcastConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type);

// Helper that validates the input data type for a reduce-precision operation,
// and returns the result shape.
static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,
@ -994,6 +994,32 @@ StatusOr<ComputationDataHandle> UserComputation::AddConvertInstruction(
return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddBitcastConvertInstruction(
const ConvertRequest& convert_request) {
tensorflow::mutex_lock lock(mutex_);

TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
LookUpRequest(convert_request.operand()));

TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape(
operand->output_shape(),
convert_request.new_element_type()));

ComputationDataHandle handle = CreateComputationDataHandle();

OperationRequest& request =
(*session_computation_.mutable_requests())[handle.handle()];
*request.mutable_output_handle() = handle;
*request.mutable_output_shape() = new_shape;
*request.mutable_request()->mutable_bitcast_convert_request() =
convert_request;

VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal()
<< "), data handle " << handle.handle() << ": "
<< convert_request.ShortDebugString();
return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReducePrecisionInstruction(
const ReducePrecisionRequest& reduce_precision_request) {
tensorflow::mutex_lock lock(mutex_);
@ -2370,6 +2396,13 @@ static void ForEachOperand(
break;
}

case OpRequest::kBitcastConvertRequest: {
const ConvertRequest& convert_request =
request.request().bitcast_convert_request();
apply(convert_request.operand());
break;
}

case OpRequest::kWhileRequest: {
const WhileRequest& while_request = request.request().while_request();
apply(while_request.init());
@ -2954,6 +2987,15 @@ void ComputationLowerer::Visit(
break;
}

case OpRequest::kBitcastConvertRequest: {
const ConvertRequest& convert_request =
request.request().bitcast_convert_request();
HloInstruction* operand = lookup_instruction(convert_request.operand());
hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert(
request.output_shape(), operand));
break;
}

case OpRequest::kWhileRequest: {
const WhileRequest& while_request = request.request().while_request();
CHECK_EQ(2, request.embedded_computation_versions_size());
@ -2978,6 +3020,25 @@ void ComputationLowerer::Visit(
HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs());
HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs());
auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop());

if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
// lhs side is being implicitly broadcast. Change to explicit.
lhs =
ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
}

if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
rhs =
ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
}

if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) {
ehs =
ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape());
}
}

hlo_instruction = add_instruction(HloInstruction::CreateTernary(
request.output_shape(), hlo_opcode, lhs, rhs, ehs));
break;
@ -3137,7 +3198,7 @@ void ComputationLowerer::Visit(
LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
}
(*instructions)[handle.handle()] = hlo_instruction;
}
} // NOLINT(readability/fn_size)

} // namespace

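The ternary-op hunk above rewrites implicit broadcasts into explicit ones when the debug flag is set. A standalone analogue of the idea (plain C++; the vector stands in for a materialized HLO broadcast):

  #include <cstddef>
  #include <vector>

  // An operand whose dimensions don't match the output shape is materialized
  // at the full output size before the elementwise op consumes it.
  std::vector<float> MakeBroadcastExplicit(float scalar_operand,
                                           std::size_t output_size) {
    return std::vector<float>(output_size, scalar_operand);
  }
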
@ -70,7 +70,7 @@ class UserComputation {

// Enqueues a pad instruction onto this user computation.
StatusOr<ComputationDataHandle> AddPadInstruction(
const PadRequest& parameter_request);
const PadRequest& pad_request);

// Enqueues a tracing instruction onto this user computation.
// Returns an error status if the operand cannot be resolved.
@ -105,7 +105,7 @@ class UserComputation {
// Enqueues a ternary instruction onto this user computation.
// Returns an error status if the operand indices are out of bounds.
StatusOr<ComputationDataHandle> AddTernaryInstruction(
const TernaryOpRequest& request);
const TernaryOpRequest& ternary_request);

// Enqueues a variadic instruction onto this user computation.
// Returns an error status if the operand indices are out of bounds.
@ -179,26 +179,30 @@ class UserComputation {

// Enqueues a concatenate instruction onto this user computation.
StatusOr<ComputationDataHandle> AddConcatenateInstruction(
const ConcatenateRequest& slice_request);
const ConcatenateRequest& concatenate_request);

// Enqueues a convert instruction onto this user computation.
StatusOr<ComputationDataHandle> AddConvertInstruction(
const ConvertRequest& convert_request);

// Enqueues a bitcast element instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBitcastConvertInstruction(
const ConvertRequest& convert_request);

// Enqueues a reduce instruction onto this user computation.
StatusOr<ComputationDataHandle> AddReduceInstruction(
const ReduceRequest& reduce_request,
const UserComputation& reduction_computation);
const UserComputation& to_apply_computation);

// Enqueues a windowed reduce instruction onto this user computation.
StatusOr<ComputationDataHandle> AddReduceWindowInstruction(
const ReduceWindowRequest& reduce_window_request,
const UserComputation& reduction_computation);
const UserComputation& to_apply_computation);

// Enqueues a select-and-scatter instruction onto this user
// computation.
StatusOr<ComputationDataHandle> AddSelectAndScatterInstruction(
const SelectAndScatterRequest& scatter_to_selected_window_element_request,
const SelectAndScatterRequest& select_and_scatter_request,
const UserComputation& select_computation,
const UserComputation& scatter_computation);

@ -403,6 +403,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {

// Compute the shape of the while op after we remove the dead indices.
std::vector<Shape> new_while_tuple_elem_shapes;
new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
for (int64 old_idx : new_to_old_tuple_idx) {
new_while_tuple_elem_shapes.push_back(
while_init->shape().tuple_shapes(old_idx));
@ -469,6 +470,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
while_body_replacements = make_while_computation_replacements(while_body);
std::vector<HloInstruction*> new_while_body_root_elems;
new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
for (int64 old_idx : new_to_old_tuple_idx) {
new_while_body_root_elems.push_back(
while_body_root->mutable_operand(old_idx));
@ -483,6 +485,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// clean this up in the common case where while_init is a tuple op. (It's
// definitely tuple-shaped, but it's not necessarily a tuple op.)
std::vector<HloInstruction*> new_while_init_elems;
new_while_init_elems.reserve(new_to_old_tuple_idx.size());
for (int64 old_idx : new_to_old_tuple_idx) {
new_while_init_elems.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
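The three hunks above all make the same micro-optimization: when the final element count is known up front, reserve() replaces repeated geometric regrowth with a single allocation. A generic sketch of the pattern (plain C++):

  #include <vector>

  std::vector<int> GatherByIndex(const std::vector<int>& src,
                                 const std::vector<int>& indices) {
    std::vector<int> out;
    out.reserve(indices.size());  // one allocation instead of several regrowths
    for (int i : indices) out.push_back(src[i]);
    return out;
  }
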
@ -382,7 +382,6 @@ xla_test(
name = "params_test",
srcs = ["params_test.cc"],
shard_count = 30,
tags = ["optonly"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",
@ -512,6 +511,7 @@ xla_test(
name = "array_elementwise_ops_test",
srcs = ["array_elementwise_ops_test.cc"],
shard_count = 25,
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
@ -770,6 +770,38 @@ xla_test(
],
)

xla_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.cc"],
shard_count = 40,
deps = [
":test_utils",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)

xla_test(
name = "slice_test",
srcs = ["slice_test.cc"],
@ -1230,6 +1262,23 @@ xla_test(
],
)

xla_test(
name = "bitcast_convert_test",
srcs = ["bitcast_convert_test.cc"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
],
)

xla_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.cc"],
75 tensorflow/compiler/xla/tests/bfloat16_test.cc Normal file
@ -0,0 +1,75 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

class Bfloat16Test : public ClientLibraryTestBase {
protected:
const ErrorSpec error_spec_{0.001, 0.001};
};

XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
DISABLED_ON_CPU(ScalarOperation)))) {
ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
builder.Add(x, y);

ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {},
error_spec_);
}

XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
DISABLED_ON_CPU(NegateScalarF16)))) {
ComputationBuilder builder(client_, TestName());
builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));

ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {},
error_spec_);
}

} // namespace
} // namespace xla
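For context on the new test: bfloat16 keeps the sign, the full 8-bit exponent, and only 7 mantissa bits of an IEEE-754 float32, which is why the tests accept a 0.001 error spec. A standalone sketch of the round-trip, assuming simple truncation (real implementations often round to nearest even instead):

  #include <cstdint>
  #include <cstdio>
  #include <cstring>

  float Bfloat16RoundTrip(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    bits &= 0xFFFF0000u;  // keep sign + exponent + top 7 mantissa bits
    float out;
    std::memcpy(&out, &bits, sizeof out);
    return out;
  }

  int main() {
    std::printf("%f\n", Bfloat16RoundTrip(2.1f));  // ~2.09375, close to 2.1
    return 0;
  }
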
141 tensorflow/compiler/xla/tests/bitcast_convert_test.cc Normal file
@ -0,0 +1,141 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <limits>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

class BitcastConvertTest : public ClientLibraryTestBase {
public:
explicit BitcastConvertTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform) {
mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
mutable_debug_options()->add_xla_disable_hlo_passes("inline");
}
};

TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<int32>({42, 64});
builder.BitcastConvertType(a, S32);

std::vector<int32> expected = {42, 64};
ComputeAndCompareR1<int32>(&builder, expected, {});
}

TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({42.0f, 64.0f});
builder.BitcastConvertType(a, F32);

std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {});
}

TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) {
ComputationBuilder builder(client_, TestName());
auto a =
builder.ConstantR1<int32>({0, static_cast<int32>(0x80000000), 0x3F800000,
static_cast<int32>(0xBF800000), 0x3F000000,
static_cast<int32>(0xBF000000)});
builder.BitcastConvertType(a, F32);

std::vector<float> expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f};
ComputeAndCompareR1<float>(&builder, expected, {});
}

XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<int32>({});
builder.BitcastConvertType(a, F32);

std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {});
}

TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({42.6, 64.4});
builder.BitcastConvertType(a, S32);

std::vector<int32> expected = {0x422a6666, 0x4280cccd};
ComputeAndCompareR1<int32>(&builder, expected, {});
}

TEST_F(BitcastConvertTest, ConvertS32Extremes) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<int32>(
{std::numeric_limits<int32>::min(), std::numeric_limits<int32>::max()});
builder.BitcastConvertType(a, F32);

std::vector<float> expected = {-0.0f, NAN};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0, 0));
}

TEST_F(BitcastConvertTest, ConvertMapToS32) {
ComputationBuilder builder(client_, TestName());
auto b = builder.CreateSubBuilder("convert");
auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in");
b->BitcastConvertType(param, S32);
auto a = builder.ConstantR1<float>({42.0f, 64.0f});
builder.Map({a}, b->BuildAndNoteError(), {0});

std::vector<int32> expected = {0x42280000, 0x42800000};
ComputeAndCompareR1<int32>(&builder, expected, {});
}

TEST_F(BitcastConvertTest, ConvertMapToF32) {
ComputationBuilder builder(client_, TestName());
auto b = builder.CreateSubBuilder("convert");
auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in");
b->BitcastConvertType(param, F32);
auto a = builder.ConstantR1<int32>({0x42280000, 0x42800000});
builder.Map({a}, b->BuildAndNoteError(), {0});

std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {});
}

// Regression test for b/31758660. When ReshapeMover transforms
// input -> reshape -> convert
// to
// input -> convert -> reshape
// the new convert should have the same element type as the old convert.
TEST_F(BitcastConvertTest, ConvertReshape) {
ComputationBuilder builder(client_, TestName());
auto input = builder.ConstantR1<int32>({0x42280000});
auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
builder.BitcastConvertType(reshape, F32);

ComputeAndCompareR0<float>(&builder, 42.0f, {});
}

} // namespace
} // namespace xla
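A note on the ConvertS32Extremes expectations above: reinterpreted as float32 bits, INT32_MIN (0x80000000) is the sign bit alone, i.e. -0.0f, while INT32_MAX (0x7FFFFFFF) has an all-ones exponent and a nonzero mantissa, i.e. a NaN. A standalone check (plain C++):

  #include <cassert>
  #include <cmath>
  #include <cstdint>
  #include <cstring>

  int main() {
    std::int32_t v[2] = {INT32_MIN, INT32_MAX};
    float f[2];
    std::memcpy(f, v, sizeof f);
    assert(f[0] == 0.0f && std::signbit(f[0]));  // -0.0f
    assert(std::isnan(f[1]));                    // NaN
    return 0;
  }
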
@ -333,6 +333,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =

@ -19,8 +19,11 @@ namespace xla {

StatusOr<std::unique_ptr<Executable>> CodegenTestBase::CompileToExecutable(
std::unique_ptr<HloModule> hlo_module) {
return backend().compiler()->Compile(std::move(hlo_module),
backend().default_stream_executor());
TF_ASSIGN_OR_RETURN(hlo_module, backend().compiler()->RunHloPasses(
std::move(hlo_module),
backend().default_stream_executor()));
return backend().compiler()->RunBackend(std::move(hlo_module),
backend().default_stream_executor());
}

StatusOr<std::unique_ptr<AotCompilationResult>>
@ -458,6 +458,54 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) {
error_spec_);
}

// Test fixture to run convolution tests with and without convolution
// canonicalization enabled.
class ConvolveWithAndWithoutCanonicalization
: public ConvolutionTest,
public ::testing::WithParamInterface<bool> {};

XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
DISABLED_ON_GPU(Convolve2D_NoSpatialDims)) {
if (GetParam()) {
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"convolution-canonicalization");
}
ComputationBuilder builder(client_, TestName());
Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29});
Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10});

auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");

ConvolutionDimensionNumbers dnums;
dnums.set_input_feature_dimension(0);
dnums.set_input_batch_dimension(1);
dnums.set_kernel_input_feature_dimension(0);
dnums.set_kernel_output_feature_dimension(1);
dnums.set_output_batch_dimension(0);
dnums.set_output_feature_dimension(1);
auto conv = builder.ConvWithGeneralDimensions(input, filter, {},
Padding::kValid, dnums);

Array2D<float> param0(4, 29);
param0.FillUnique();

Array2D<float> param1(4, 10);
param1.FillUnique();

Array2D<float> expected_result(29, 10);
expected_result.Fill(0);

ComputeAndCompare(
&builder, conv,
{*Literal::CreateFromArray(param0), *Literal::CreateFromArray(param1)},
error_spec_);
}

INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation,
ConvolveWithAndWithoutCanonicalization,
::testing::Values(true, false));

struct Convolve1DTestParam {
int64 input_feature;
int64 output_feature;
@ -340,6 +340,9 @@ class NearComparator {
multi_index_.resize(expected.shape().dimensions_size(), 0);

switch (expected.shape().element_type()) {
case BF16:
ExpectLiteralsNear<bfloat16>(expected, actual, 0);
break;
case F32:
ExpectLiteralsNear<float>(expected, actual, 0);
break;
@ -525,6 +528,13 @@ void NearComparator::ExpectNear<complex64>(complex64 expected, complex64 actual,
<< message;
}

template <>
bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
bfloat16 actual) {
return ExpectValuesNear(static_cast<float>(expected),
static_cast<float>(actual));
}

} // namespace

/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
|
||||
compiler->SetPostOptimizationHook(post_opt_hook);
|
||||
|
||||
ASSERT_TRUE(compiler
|
||||
->Compile(std::move(hlo_module),
|
||||
backend_->default_stream_executor())
|
||||
->RunBackend(std::move(hlo_module),
|
||||
backend_->default_stream_executor())
|
||||
.ok());
|
||||
|
||||
// Test that hooks were called.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.