Merge branch 'master' into doc-patch-batch-normalization

This commit is contained in:
Brad Huang 2019-04-12 10:50:33 -07:00 committed by GitHub
commit 6004a65a77
1139 changed files with 41179 additions and 16719 deletions


@@ -25,7 +25,7 @@ networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.
TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards
compatible API's for C++, Go, Java, JavaScript and Swift.
compatible API's for C++, Go, Java, JavaScript, and Swift.
Keep up to date with release announcements and security updates by
subscribing to


@@ -43,47 +43,37 @@ remote_config_workspace()
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "4b90786009fa8df25230442244bad2832ba8d6bc4987f68150a7de59c8827e90",
strip_prefix = "rules_apple-0.14.0",
urls = ["https://github.com/bazelbuild/rules_apple/archive/0.14.0.tar.gz"],
)
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"],
)
http_archive(
name = "bazel_skylib",
sha256 = "2c62d8cd4ab1e65c08647eb4afe38f51591f43f7f0885e7769832fa137633dcb",
strip_prefix = "bazel-skylib-0.7.0",
urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.7.0.tar.gz"],
)
sha256 = "8f32e2839fba28d549e1670dbed83606dd339a9f7489118e481814d61738270f",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.14.0/rules_apple.0.14.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "835663c4bb02f4bf01dce8a2a176df7fa682dbb867d3698ae12258c1628bb8f0",
strip_prefix = "apple_support-0.5.0",
urls = ["https://github.com/bazelbuild/apple_support/archive/0.5.0.tar.gz"],
)
sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "32d124878cd49775d84f59ba90440c8b23b7c775aec8fec1978f751c76ddee8a",
strip_prefix = "rules_swift-0.7.0",
urls = ["https://github.com/bazelbuild/rules_swift/archive/0.7.0.tar.gz"],
)
sha256 = "31aad005a9c4e56b256125844ad05eb27c88303502d74138186f9083479f93a6",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.8.0/rules_swift.0.8.0.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.2.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.2.0.zip"],
)
# Use swift_rules_dependencies to fetch the tolchains.
# Since we defined all the "git_repository" rules above, the following call will
# skip redefining them.
strip_prefix = "swift-protobuf-1.4.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.4.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
swift_rules_dependencies()


@@ -799,8 +799,8 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
const auto& op_type = op->operation.Name();
auto op_name =
tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
auto* desc =
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
std::unique_ptr<TF_OperationDescription> desc(
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()));
VLOG(1) << "Adding attrs.";
tensorflow::AttrValueMap attrs;
@@ -814,30 +814,42 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
size_t inputIndex = 0;
const tensorflow::OpDef& op_def = desc->node_builder.op_def();
for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
// TODO(bgogul): Add support for number attributes.
DCHECK(input_arg.number_attr().empty())
<< "Number attributes is not implemented yet.";
if (input_arg.type_list_attr().empty()) {
if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) {
auto symbolic_input =
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
if (!status->status.ok()) return nullptr;
TF_AddInput(desc, symbolic_input);
TF_AddInput(desc.get(), symbolic_input);
continue;
}
const std::string& type_list_attr = input_arg.type_list_attr();
const auto& attr_value = attrs[type_list_attr];
DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
<< "Type list attribute should be a list!";
std::vector<TF_Output> list_inputs(attr_value.list().type_size());
size_t list_size = 0;
if (!input_arg.type_list_attr().empty()) {
const std::string& type_list_attr = input_arg.type_list_attr();
const auto& attr_value = attrs[type_list_attr];
CHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
<< "Type list attribute should be a list!";
list_size = attr_value.list().type_size();
} else {
CHECK(!input_arg.number_attr().empty());
const auto& attr_value = attrs[input_arg.number_attr()];
CHECK(attr_value.value_case() == tensorflow::AttrValue::kI)
<< "Number attribute should be int!";
if (attr_value.i() < 0) {
status->status = tensorflow::errors::Internal(
"Number attribute for length should be >=0!");
return nullptr;
}
list_size = attr_value.i();
}
std::vector<TF_Output> list_inputs(list_size);
for (TF_Output& list_input : list_inputs) {
list_input =
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
if (!status->status.ok()) return nullptr;
}
TF_AddInputList(desc, list_inputs.data(), list_inputs.size());
TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size());
}
auto* graph_op = TF_FinishOperation(desc, status);
auto* graph_op = TF_FinishOperation(desc.release(), status);
if (!status->status.ok()) return nullptr;
VLOG(1) << "Op finalized; setting return tensors.";
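
The change above swaps the raw TF_OperationDescription pointer for a std::unique_ptr so that the early `return nullptr` paths (including the new one for a negative number attribute) cannot leak the partially built description; TF_FinishOperation still takes ownership of it, hence the `.release()`. Below is a minimal sketch of that ownership pattern, using OpBuilder, StartOp, and FinishOp as invented stand-ins rather than the real TensorFlow C API:
#include <memory>
#include <string>
struct OpBuilder {
  std::string name;
};
// Returns a heap-allocated builder, as TF_NewOperation does for descriptions.
OpBuilder* StartOp(const std::string& name) { return new OpBuilder{name}; }
// Consumes the builder unconditionally, mirroring TF_FinishOperation.
bool FinishOp(OpBuilder* builder, std::string* out_name) {
  std::unique_ptr<OpBuilder> owned(builder);  // reclaimed here, success or not
  *out_name = owned->name;
  return true;
}
bool BuildOp(const std::string& name, bool inputs_ok, std::string* out_name) {
  std::unique_ptr<OpBuilder> desc(StartOp(name));
  if (!inputs_ok) {
    return false;  // early exit: desc is freed automatically, nothing leaks
  }
  // Hand ownership to the finishing call, like desc.release() in the diff.
  return FinishOp(desc.release(), out_name);
}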


@@ -376,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
TFE_DeleteOp(identityn);
}
TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) {
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
TFE_OpSetAttrInt(concatv2, "N", 2);
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
constexpr size_t kNumInputs = 2;
for (size_t i = 0; i < kNumInputs; ++i) {
TFE_OpAddInput(concatv2, matrix, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
}
TFE_OpAddInput(concatv2, axis, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
AddEagerOpToGraphAndCheck(
concatv2, [this, kNumInputs](TF_Operation* graph_op) {
EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1);
int64_t attrN;
TF_OperationGetAttrInt(graph_op, "N", &attrN, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
EXPECT_EQ(attrN, kNumInputs);
EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_),
kNumInputs);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
});
TFE_DeleteTensorHandle(axis);
TFE_DeleteTensorHandle(matrix);
TFE_DeleteOp(concatv2);
}
TEST_F(AddEagerOpToGraphTest,
GeneratesInternalErrorsForInvalidNumberAttributes) {
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
int num_retvals = 5;
TFE_TensorHandle* retvals[5];
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
TFE_OpSetAttrInt(concatv2, "N", -1);
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals,
&num_retvals, status_);
EXPECT_EQ(graph_op, nullptr);
EXPECT_EQ(status_->status.error_message(),
"Number attribute for length should be >=0!");
TFE_DeleteOp(concatv2);
TFE_DeleteTensorHandle(axis);
TFE_DeleteTensorHandle(matrix);
}
} // namespace
} // namespace tensorflow


@@ -143,7 +143,9 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
tensorflow::eager::EagerClient* eager_client;
TF_RETURN_IF_ERROR(
remote_eager_workers->GetClient(remote_worker, &eager_client));
if (eager_client == nullptr) {
return tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);


@@ -80,11 +80,14 @@ void ExecuteWithProfiling(bool async) {
profiler_result->length}));
string profile_proto_str = profile_proto.DebugString();
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
// device name with "stream:all" is collected by Device Tracer.
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
}
EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:CPU:0"));
// This is collected by TraceMe
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
TF_DeleteBuffer(profiler_result);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);


@@ -21,7 +21,6 @@ from __future__ import print_function as _print_function
import os as _os
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# API IMPORTS PLACEHOLDER


@@ -732,44 +732,6 @@ tf_cc_test(
],
)
cc_library(
name = "xla_fusion_optimizer",
srcs = ["xla_fusion_optimizer.cc"],
hdrs = ["xla_fusion_optimizer.h"],
visibility = ["//visibility:public"],
deps = [
":common",
":compilation_passes",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "xla_fusion_optimizer_test",
srcs = ["xla_fusion_optimizer_test.cc"],
deps = [
":common",
":xla_cluster_util",
":xla_fusion_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler/utils:grappler_test",
],
)
cc_library(
name = "node_matchers",
testonly = True,


@@ -441,19 +441,48 @@ class MarkForCompilationPassImpl {
bool has_xla_compile_attr;
};
// Nodes that XLA can compile are put in `candidates`. Nodes put in
// `isolated_nodes` must either be unclustered or be put in trivial
// ---------------------------------------------------------------------------
// The pass proceeds in four steps, out of which `RunEdgeContractionLoop` and
// `CreateClusters` do most of the heavy lifting.
// Initialize some internal data structures.
Status Initialize();
// Contracts as many edges as possible to create XLA clusters. After this
// finishes the clustering decisions made are implicitly stored in
// `clusters_`.
Status RunEdgeContractionLoop();
// Manifests the clustering decisions into the TF graph by tagging nodes with
// an `_XlaCluster` attribute. Also some basic filter logic, like
// tf_xla_min_cluster_size, are applied here.
Status CreateClusters();
Status DumpDebugInfo();
bool IsCompilationCandidate(Node* n) const {
return compilation_candidates_.find(n) != compilation_candidates_.end();
}
// Tries to contract the edge from cluster `from` to cluster `to`. Returns
// true if successful.
StatusOr<bool> TryToContractEdge(const Cluster& from, int to);
// Tries to contract each edge from `cluster_from`. Returns true as soon as a
// single edge contraction is successful. Returns false if no edges were
// contracted.
StatusOr<bool> TryToContractEdgeFrom(Cluster* cluster_from);
// Nodes that XLA can compile are put in `candidates_`. Nodes put in
// `isolated_nodes_` must either be unclustered or be put in trivial
// single-node clusters.
StatusOr<std::pair<OrderedNodeSet, absl::flat_hash_set<Node*>>>
FindCompilationCandidates();
Status FindCompilationCandidates();
bool CompilationDisallowedByXlaCompileAttr(Node* node,
const DeviceType& jit_device_type);
Status BuildInitialClusterSet(const OrderedNodeSet& compilation_candidates,
const DeadnessAnalysis* deadness_analysis,
std::vector<UnionFind<Cluster>>* clusters,
std::deque<UnionFind<Cluster>*>* worklist);
// Populates `clusters_` and `worklist_`.
Status BuildInitialClusterSet();
StatusOr<bool> ShouldCompileClusterImpl(const Cluster& cluster);
@@ -461,9 +490,7 @@ class MarkForCompilationPassImpl {
bool HasMismatchingXlaScope(Node* node_from, Node* node_to);
StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
int to_node_id, const OrderedNodeSet& compilation_candidates,
absl::Span<UnionFind<Cluster>> clusters, const GraphCycles& cycles);
StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(int to);
// Returns true if the devices in `cluster_a` and `cluster_b` are compatible
// and therefore not a hindrance for combining the two clusters into a larger
@@ -481,13 +508,217 @@ class MarkForCompilationPassImpl {
OptimizerOptions::GlobalJitLevel global_jit_level_;
absl::flat_hash_map<int, bool> should_compile_cluster_cache_;
DeviceInfoCache device_info_cache_;
bool initialized_ = false;
bool edges_contracted_ = false;
bool clusters_created_ = false;
std::vector<UnionFind<Cluster>> clusters_;
std::deque<UnionFind<Cluster>*> worklist_;
GraphCycles graph_cycles_;
OrderedNodeSet compilation_candidates_;
absl::flat_hash_set<Node*> isolated_nodes_;
std::unique_ptr<DeadnessAnalysis> deadness_analysis_;
int64 iteration_count_ = 0;
};
Status IgnoreResourceOpForSafetyAnalysis(DeviceInfoCache* device_info_cache,
const Node& n, bool* ignore) {
// If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
// ignore it during resource operation safety analysis. We need this hack
// because of two reasons:
//
// 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
// 2. We don't support live-out values of type DT_RESOURCE and live-in values
// of type DT_RESOURCE that are not resource variables.
//
// Together these imply we cannot let resource variable safety analysis
// constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
// clusters: both of them will have to be clustered because of (1) and we
// won't be able to keep the edge between the two as neither the input to the
// second XLA cluster nor the output from the first XLA cluster are supported
// because of (2).
//
// TODO(b/113100872): This can be fixed if the TensorFlow representation for
// TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
// (2) would no longer hold.
if (n.assigned_device_name().empty()) {
*ignore = false;
return Status::OK();
}
TF_ASSIGN_OR_RETURN(
const XlaOpRegistry::DeviceRegistration* registration,
device_info_cache->GetCompilationDevice(n.assigned_device_name()));
if (!registration) {
*ignore = true;
} else {
*ignore = registration->cluster_resource_variable_ops_unsafely;
}
return Status::OK();
}
Status MarkForCompilationPassImpl::Initialize() {
TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_);
initialized_ = true;
TF_RETURN_IF_ERROR(FindCompilationCandidates());
if (compilation_candidates_.empty()) {
VLOG(2) << "No compilable candidates";
return Status::OK();
}
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
CreateCycleDetectionGraph(graph_, &graph_cycles_));
if (!cycle_detection_graph_ok) {
return Status::OK();
}
auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
};
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
graph_, flib_def_, ignore_resource_ops, &graph_cycles_));
if (!debug_options_.ignore_deadness_checks) {
XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_));
}
// Each compilation candidate belongs to a cluster. The cluster's
// representative names the node in the 'cycles' graph that represents the
// cluster.
return BuildInitialClusterSet();
}
Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_);
edges_contracted_ = true;
// TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
// example, from the Grappler fusion pass).
while (!worklist_.empty()) {
UnionFind<Cluster>* cluster_from = worklist_.front();
worklist_.pop_front();
TF_ASSIGN_OR_RETURN(bool contracted_one_edge,
TryToContractEdgeFrom(&cluster_from->Get()));
if (contracted_one_edge) {
worklist_.push_back(cluster_from);
}
}
VLOG(1) << iteration_count_ << " iterations in inner loop for graph with "
<< compilation_candidates_.size()
<< " compilation candidates. Iterations per compilation candidate: "
<< ((1.0 * iteration_count_) / compilation_candidates_.size());
return Status::OK();
}
Status MarkForCompilationPassImpl::CreateClusters() {
TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
clusters_created_ = true;
static std::atomic<int64> cluster_sequence_num;
// Count the number of non-trivial elements in each cluster.
std::vector<int> effective_cluster_sizes(graph_->num_node_ids());
// has_functional_control_flow remembers if a cluster contains a functional
// control flow node.
std::vector<bool> has_functional_control_flow(graph_->num_node_ids());
for (const Node* n : compilation_candidates_) {
int cluster = clusters_[n->id()].Get().representative;
// We want clusters to be big enough that the benefit from XLA's
// optimizations offsets XLA related overhead (for instance we add some
// Switch/Merge nodes into the graph to implement lazy compilation). To
// this end, we don't count Identity and Constant nodes because they do not
// enable interesting optimizations by themselves.
if (!n->IsIdentity() && !n->IsConstant()) {
effective_cluster_sizes[cluster]++;
}
if (n->type_string() == "While" || n->type_string() == "If") {
has_functional_control_flow[cluster] = true;
}
}
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
if (debug_options_.dump_graphs) {
DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
}
// Mark clusters for compilation that:
// * are placed on a device that requires compilation (an XlaDevice),
// * are explicitly marked for compilation (_XlaCompile=true), or
// * have more than debug_options_.xla_min_cluster_size elements (applicable
// only if compilation is enabled, otherwise there will be no such
// candidates).
for (Node* n : compilation_candidates_) {
const Cluster& cluster = clusters_[n->id()].Get();
TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
ShouldCompileCluster(cluster));
if (!should_compile_cluster) {
continue;
}
int cluster_repr = cluster.representative;
// Compile if the user marked this node _XlaCompile=true
bool compile_attr = false;
bool marked_for_compilation = false;
if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
marked_for_compilation = compile_attr;
} else if (flib_def_->GetAttr(*n, kXlaCompileAttr, &compile_attr).ok()) {
marked_for_compilation = compile_attr;
}
// We assume that functional If and While nodes have at least
// min_cluster_size non-trivial nodes in them. It would be more principled
// to (recursively) verify this fact, but that's probably not worth the
// trouble.
if (effective_cluster_sizes[cluster_repr] >=
debug_options_.min_cluster_size ||
has_functional_control_flow[cluster_repr] || marked_for_compilation) {
string& name = cluster_names[cluster_repr];
if (name.empty()) {
name = absl::StrCat("cluster_", cluster_sequence_num++);
}
n->AddAttr(kXlaClusterAttr, name);
n->AddAttr(kXlaAlreadyClustered, true);
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
}
}
return Status::OK();
}
Status MarkForCompilationPassImpl::DumpDebugInfo() {
TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_);
if (debug_options_.dump_graphs) {
DumpPostClusteringGraphs();
}
VLogClusteringSummary();
return Status::OK();
}
StatusOr<bool>
MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
int to_node_id, const OrderedNodeSet& compilation_candidates,
absl::Span<UnionFind<Cluster>> clusters, const GraphCycles& cycles) {
const Cluster& cluster_to = clusters[to_node_id].Get();
int to) {
const Cluster& cluster_to = clusters_[to].Get();
// If any of the consumer's producers are on a different device, do not
// cluster these nodes. This prevents other work on this device from being
@@ -499,14 +730,14 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
//
// TODO(b/117085735): We probably want to handle the reciprocal of this case
// where a cluster is producing data for multiple devices.
for (const auto& in_id : cycles.Predecessors(to_node_id)) {
for (const auto& in_id : graph_cycles_.Predecessors(to)) {
if (in_id >= graph_->num_node_ids()) {
continue;
}
Node* in = graph_->FindNodeId(in_id);
const Cluster& cluster_in = clusters[in_id].Get();
if (compilation_candidates.find(in) != compilation_candidates.cend()) {
const Cluster& cluster_in = clusters_[in_id].Get();
if (IsCompilationCandidate(in)) {
TF_ASSIGN_OR_RETURN(bool devices_compatible,
AreDevicesCompatible(cluster_to, cluster_in));
if (!devices_compatible) {
@@ -537,20 +768,17 @@ bool MarkForCompilationPassImpl::HasMismatchingXlaScope(Node* node_from,
from_scope != to_scope;
}
Status MarkForCompilationPassImpl::BuildInitialClusterSet(
const OrderedNodeSet& compilation_candidates,
const DeadnessAnalysis* deadness_analysis,
std::vector<UnionFind<Cluster>>* clusters,
std::deque<UnionFind<Cluster>*>* worklist) {
clusters->resize(graph_->num_node_ids());
for (Node* node : compilation_candidates) {
Cluster* cluster = &(*clusters)[node->id()].Get();
Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
clusters_.resize(graph_->num_node_ids());
for (Node* node : compilation_candidates_) {
Cluster* cluster = &clusters_[node->id()].Get();
cluster->representative = node->id();
if (deadness_analysis) {
if (deadness_analysis_) {
TF_ASSIGN_OR_RETURN(
cluster->deadness_predicate,
deadness_analysis->GetPredicateFor(node, Graph::kControlSlot));
deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
}
const string& device = !node->assigned_device_name().empty()
@@ -572,17 +800,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet(
}
cluster->devices.insert(device);
worklist->push_back(&(*clusters)[node->id()]);
worklist_.push_back(&clusters_[node->id()]);
}
return Status::OK();
}
StatusOr<std::pair<OrderedNodeSet, absl::flat_hash_set<Node*>>>
MarkForCompilationPassImpl::FindCompilationCandidates() {
OrderedNodeSet candidates;
absl::flat_hash_set<Node*> isolated_nodes;
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env_, TF_GRAPH_DEF_VERSION,
@@ -698,17 +923,18 @@ MarkForCompilationPassImpl::FindCompilationCandidates() {
if (!is_tensor_array_or_stack_op) {
VLOG(2) << "Isolating " << node->name()
<< ": must-be-constant stateful op";
isolated_nodes.insert(node);
isolated_nodes_.insert(node);
}
}
}
candidates.insert(node);
compilation_candidates_.insert(node);
--(*debug_options_.fuel);
}
VLOG(2) << "candidates->size() = " << candidates.size();
return {{candidates, isolated_nodes}};
VLOG(2) << "compilation_candidates_.size() = "
<< compilation_candidates_.size();
return Status::OK();
}
bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
@@ -753,47 +979,119 @@ bool IsShapeConsumerOp(const Node& node) {
node.type_string() == "Size";
}
Status IgnoreResourceOpForSafetyAnalysis(DeviceInfoCache* device_info_cache,
const Node& n, bool* ignore) {
// If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
// ignore it during resource operation safety analysis. We need this hack
// because of two reasons:
//
// 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
// 2. We don't support live-out values of type DT_RESOURCE and live-in values
// of type DT_RESOURCE that are not resource variables.
//
// Together these imply we cannot let resource variable safety analysis
// constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
// clusters: both of them will have to be clustered because of (1) and we
// won't be able to keep the edge between the two as neither the input to the
// second XLA cluster nor the output from the first XLA cluster are supported
// because of (2).
//
// TODO(b/113100872): This can be fixed if the TensorFlow representation for
// TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
// (2) would no longer hold.
if (n.assigned_device_name().empty()) {
*ignore = false;
return Status::OK();
StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(
const Cluster& cluster_from, int to) {
Node* node_to = graph_->FindNodeId(to);
if (!IsCompilationCandidate(node_to)) {
return false;
}
TF_ASSIGN_OR_RETURN(
const XlaOpRegistry::DeviceRegistration* registration,
device_info_cache->GetCompilationDevice(n.assigned_device_name()));
if (!registration) {
*ignore = true;
} else {
*ignore = registration->cluster_resource_variable_ops_unsafely;
const Cluster& cluster_to = clusters_[to].Get();
DCHECK(cluster_from.deadness_predicate.has_value() ==
cluster_to.deadness_predicate.has_value());
if (cluster_from.deadness_predicate != cluster_to.deadness_predicate) {
return false;
}
return Status::OK();
TF_ASSIGN_OR_RETURN(bool devices_compatible,
AreDevicesCompatible(cluster_from, cluster_to));
if (!devices_compatible) {
return false;
}
if (isolated_nodes_.contains(node_to)) {
return false;
}
int from = cluster_from.representative;
Node* node_from = graph_->FindNodeId(from);
if (HasMismatchingXlaScope(node_from, node_to)) {
return false;
}
// Ops that consume shapes cannot be the root of a cluster. This is an
// optimization.
if (clusters_[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
return false;
}
// Don't exceed the maximum cluster size.
if (clusters_[from].Size() + clusters_[to].Size() >
debug_options_.max_cluster_size) {
return false;
}
TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
ClusteringWillIntroduceInterDeviceDependency(to));
if (will_introduce_cross_device_dependency) {
return false;
}
// If contracting the edge would create a cycle, bail out. However, just
// because we can't merge the clusters now does not mean we won't be able
// to merge them in the future. e.g., if we have edges 1->2, 2->3 and
// 1->3, we cannot contract edge 1->3. But if we first contract 1->2 then
// we can later contract 1->3.
return graph_cycles_.ContractEdge(from, to);
}
StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdgeFrom(
Cluster* cluster_from) {
int from = cluster_from->representative;
Node* node_from = graph_->FindNodeId(from);
if (node_from->IsControlFlow()) {
// Control flow nodes aren't compilation candidates and should never
// appear.
return errors::Internal("Found control flow node in clustering worklist: ",
node_from->type_string());
}
if (!IsCompilationCandidate(node_from)) {
return false;
}
if (isolated_nodes_.count(node_from)) {
return false;
}
for (int to : graph_cycles_.Successors(from)) {
iteration_count_++;
if (to >= graph_->num_node_ids()) {
// Node is a fictitious node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
TF_ASSIGN_OR_RETURN(bool contracted_edge,
TryToContractEdge(*cluster_from, to));
if (!contracted_edge) {
continue;
}
const Cluster& cluster_to = clusters_[to].Get();
// Merge the clusters. ContractEdge uses 'from' as the number of the
// merged node, so make sure 'from' is the chosen representative.
cluster_from->devices.insert(cluster_to.devices.begin(),
cluster_to.devices.end());
if (!cluster_to.resource_op_device.empty()) {
cluster_from->resource_op_device = cluster_to.resource_op_device;
}
cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr;
clusters_[from].Merge(&clusters_[to]);
return true;
}
return false;
}
Status MarkForCompilationPassImpl::Run() {
static std::atomic<int64> cluster_sequence_num;
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
@@ -801,232 +1099,10 @@ Status MarkForCompilationPassImpl::Run() {
// some one-time work.
XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1);
OrderedNodeSet compilation_candidates;
absl::flat_hash_set<Node*> isolated_nodes;
TF_ASSIGN_OR_RETURN(std::tie(compilation_candidates, isolated_nodes),
FindCompilationCandidates());
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
return Status::OK();
}
GraphCycles cycles;
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
CreateCycleDetectionGraph(graph_, &cycles));
if (!cycle_detection_graph_ok) {
return Status::OK();
}
auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
};
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
graph_, flib_def_, ignore_resource_ops, &cycles));
std::unique_ptr<DeadnessAnalysis> deadness_analysis;
if (!debug_options_.ignore_deadness_checks) {
XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis));
}
// Each compilation candidate belongs to a cluster. The cluster's
// representative names the node in the 'cycles' graph that represents the
// cluster.
std::vector<UnionFind<Cluster>> clusters;
std::deque<UnionFind<Cluster>*> worklist;
TF_RETURN_IF_ERROR(BuildInitialClusterSet(
compilation_candidates, deadness_analysis.get(), &clusters, &worklist));
int64 iteration_count = 0;
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
//
// TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
// example, from the Grappler fusion pass).
while (!worklist.empty()) {
Cluster* cluster_from = &worklist.front()->Get();
int from = cluster_from->representative;
worklist.pop_front();
Node* node_from = graph_->FindNodeId(from);
if (node_from->IsControlFlow()) {
// Control flow nodes aren't compilation candidates and should never
// appear.
return errors::Internal(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
if (isolated_nodes.count(node_from)) {
continue;
}
for (int to : cycles.Successors(from)) {
iteration_count++;
if (to >= graph_->num_node_ids()) {
// Node is a fictitious node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
const Cluster& cluster_to = clusters[to].Get();
Node* node_to = graph_->FindNodeId(to);
if (compilation_candidates.find(node_to) ==
compilation_candidates.cend()) {
continue;
}
DCHECK(cluster_from->deadness_predicate.has_value() ==
cluster_to.deadness_predicate.has_value());
if (cluster_from->deadness_predicate != cluster_to.deadness_predicate) {
continue;
}
TF_ASSIGN_OR_RETURN(bool devices_compatible,
AreDevicesCompatible(*cluster_from, cluster_to));
if (!devices_compatible) {
continue;
}
if (isolated_nodes.count(node_to)) {
continue;
}
if (HasMismatchingXlaScope(node_from, node_to)) {
continue;
}
// Ops that consume shapes cannot be the root of a cluster. This is an
// optimization.
if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
continue;
}
// Don't exceed the maximum cluster size.
if (clusters[from].Size() + clusters[to].Size() >
debug_options_.max_cluster_size) {
continue;
}
TF_ASSIGN_OR_RETURN(
bool will_introduce_cross_device_dependency,
ClusteringWillIntroduceInterDeviceDependency(
to, compilation_candidates, absl::MakeSpan(clusters), cycles));
if (will_introduce_cross_device_dependency) {
continue;
}
// If contracting the edge would create a cycle, bail out. However, just
// because we can't merge the clusters now does not mean we won't be able
// to merge them in the future. e.g., if we have edges 1->2, 2->3 and
// 1->3, we cannot contract edge 1->3. But if we first contract 1->2 then
// we can later contract 1->3.
if (!cycles.ContractEdge(from, to)) {
continue;
}
// Merge the clusters. ContractEdge uses 'from' as the number of the
// merged node, so make sure 'from' is the chosen representative.
cluster_from->devices.insert(cluster_to.devices.begin(),
cluster_to.devices.end());
if (!cluster_to.resource_op_device.empty()) {
cluster_from->resource_op_device = cluster_to.resource_op_device;
}
cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr;
clusters[from].Merge(&clusters[to]);
worklist.push_back(&clusters[from]);
break;
}
}
VLOG(1) << iteration_count << " iterations in inner loop for graph with "
<< compilation_candidates.size()
<< " compilation candidates. Iterations per compilation candidate: "
<< ((1.0 * iteration_count) / compilation_candidates.size());
// Count the number of non-trivial elements in each cluster.
std::vector<int> effective_cluster_sizes(graph_->num_node_ids());
// has_functional_control_flow remembers if a cluster contains a functional
// control flow node.
std::vector<bool> has_functional_control_flow(graph_->num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
// We want clusters to be big enough that the benefit from XLA's
// optimizations offsets XLA related overhead (for instance we add some
// Switch/Merge nodes into the graph to implement lazy compilation). To
// this end, we don't count Identity and Constant nodes because they do not
// enable interesting optimizations by themselves.
if (!n->IsIdentity() && !n->IsConstant()) {
effective_cluster_sizes[cluster]++;
}
if (n->type_string() == "While" || n->type_string() == "If") {
has_functional_control_flow[cluster] = true;
}
}
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
if (debug_options_.dump_graphs) {
DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
}
// Mark clusters for compilation that:
// * are placed on a device that requires compilation (an XlaDevice),
// * are explicitly marked for compilation (_XlaCompile=true), or
// * have more than debug_options_.xla_min_cluster_size elements (applicable
// only if compilation is enabled, otherwise there will be no such
// candidates).
for (Node* n : compilation_candidates) {
const Cluster& cluster = clusters[n->id()].Get();
TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
ShouldCompileCluster(cluster));
if (!should_compile_cluster) {
continue;
}
int cluster_repr = cluster.representative;
// Compile if the user marked this node _XlaCompile=true
bool compile_attr = false;
bool marked_for_compilation = false;
if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
marked_for_compilation = compile_attr;
} else if (flib_def_->GetAttr(*n, kXlaCompileAttr, &compile_attr).ok()) {
marked_for_compilation = compile_attr;
}
// We assume that functional If and While nodes have at least
// min_cluster_size non-trivial nodes in them. It would be more principled
// to (recursively) verify this fact, but that's probably not worth the
// trouble.
if (effective_cluster_sizes[cluster_repr] >=
debug_options_.min_cluster_size ||
has_functional_control_flow[cluster_repr] || marked_for_compilation) {
string& name = cluster_names[cluster_repr];
if (name.empty()) {
name = absl::StrCat("cluster_", cluster_sequence_num++);
}
n->AddAttr(kXlaClusterAttr, name);
n->AddAttr(kXlaAlreadyClustered, true);
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
}
}
if (debug_options_.dump_graphs) {
DumpPostClusteringGraphs();
}
VLogClusteringSummary();
TF_RETURN_IF_ERROR(Initialize());
TF_RETURN_IF_ERROR(RunEdgeContractionLoop());
TF_RETURN_IF_ERROR(CreateClusters());
TF_RETURN_IF_ERROR(DumpDebugInfo());
return Status::OK();
}
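
The rewritten pass now runs as Initialize, then RunEdgeContractionLoop, then CreateClusters, then DumpDebugInfo, with the worklist-plus-union-find contraction recording the clustering decisions in clusters_. The sketch below is a simplified, self-contained rendering of that contraction loop only: can_merge stands in for the real checks (deadness predicates, device compatibility, XLA scopes, maximum cluster size, and the cycle test inside ContractEdge), and DisjointSets/ClusterGraph are names invented for illustration.
#include <deque>
#include <functional>
#include <numeric>
#include <set>
#include <vector>
struct DisjointSets {
  std::vector<int> parent;
  explicit DisjointSets(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
};
// Worklist-driven edge contraction: pop a cluster, try to absorb one
// successor, and requeue the cluster whenever it managed to grow.
std::vector<int> ClusterGraph(
    int num_nodes, std::vector<std::set<int>> successors,
    const std::function<bool(int from, int to)>& can_merge) {
  DisjointSets clusters(num_nodes);
  std::deque<int> worklist(num_nodes);
  std::iota(worklist.begin(), worklist.end(), 0);
  while (!worklist.empty()) {
    int from = worklist.front();
    worklist.pop_front();
    if (clusters.Find(from) != from) continue;  // already absorbed elsewhere
    bool contracted = false;
    for (int succ : std::set<int>(successors[from])) {
      int to = clusters.Find(succ);
      if (to == from || !can_merge(from, to)) continue;
      // Contract the edge: 'from' stays the representative of the merged
      // cluster and inherits the absorbed cluster's successors.
      clusters.parent[to] = from;
      successors[from].insert(successors[to].begin(), successors[to].end());
      successors[from].erase(from);
      contracted = true;
      break;
    }
    if (contracted) worklist.push_back(from);  // it may be able to grow again
  }
  std::vector<int> representative(num_nodes);
  for (int i = 0; i < num_nodes; ++i) representative[i] = clusters.Find(i);
  return representative;
}
In the pass itself, CreateClusters then walks the merged representatives to size, name, and tag the surviving clusters with the _XlaCluster attribute.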


@@ -519,7 +519,7 @@ Status XlaDevice::RefreshStatus() {
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) {
// Any op assigned to the device that isn't rewritten by the graph rewriter
// gets executed by a n XlaCompileOnDemandOp, which compiles it and executes
// gets executed by an XlaCompileOnDemandOp, which compiles it and executes
// it just-in-time.
OpKernel* (*factory)(OpKernelConstruction*) =
[](OpKernelConstruction* context) -> OpKernel* {


@@ -247,6 +247,9 @@ class XlaAssignVariableOp : public OpKernel {
data::MakeIteratorOp); \
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER( \
Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \
data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
data::IteratorGetNextOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \


@@ -1,352 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
#include <atomic>
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
namespace tensorflow {
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
static bool IsShapeConsumerOp(const Node& node) {
return node.type_string() == "Shape" || node.type_string() == "ShapeN" ||
node.type_string() == "Rank" || node.type_string() == "Size";
}
// Returns true if the op can be decomposed into XLA ops for which
// there are fusible elemental implementations.
static bool IsXlaFusible(const NodeDef& node) {
static const std::unordered_set<std::string>* elementwise_ops =
new std::unordered_set<std::string>(
{// tf2xla/kernels/aggregate_ops.cc
"AddN",
// tf2xla/kernels/binary_ops.cc
"Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
"FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
"TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
"GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
"SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
// tf2xla/kernels/unary_ops.cc
"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
"Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
"Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
"Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
"Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
"Square", "Tan", "Tanh", "Real", "Imag",
// tf2xla/kernels/bcast_ops.cc
"BroadcastArgs", "BroadcastGradientArgs",
// tf2xla/kernels/bias_ops.cc
"BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
// tf2xla/kernels/cast_op.cc
"Cast",
// tf2xla/kernels/concat_op.cc
"Concat", "ConcatV2", "ConcatOffset",
// tf2xla/kernels/const_op.cc
"Const",
// tf2xla/kernels/elu_op.cc
"Elu", "EluGrad", "Selu", "SeluGrad",
// tf2xla/kernels/fill_op.cc
"Fill",
// tf2xla/kernels/identity_op.cc
"Identity", "IdentityN", "PreventGradient",
"StopGradient", /*"Snapshot",*/
// tf2xla/kernels/index_ops.cc
"ArgMax", "ArgMin",
// tf2xla/kernels/mirror_pad_op.cc
"MirrorPad",
// tf2xla/kernels/one_hot_op.cc
"OneHot",
// tf2xla/kernels/pack_op.cc
"Pack",
// tf2xla/kernels/pad_op.cc
"Pad", "PadV2",
// tf2xla/kernels/relu_op.cc
"Relu", "Relu6", "ReluGrad", "Relu6Grad",
// tf2xla/kernels/reshape_op.cc
"Reshape",
// tf2xla/kernels/reverse_op.cc
"Reverse", "ReverseV2",
// tf2xla/kernels/reverse_sequence_op.cc
"ReverseSequence",
// tf2xla/kernels/shape_op.cc
"Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
"ZerosLike", "OnesLike",
// tf2xla/kernels/slice_op.cc
"Slice",
// tf2xla/kernels/split_op.cc
"Split", "SplitV",
// tf2xla/kernels/strided_slice_op.cc
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
// tf2xla/kernels/tile_ops.cc
"Tile",
// tf2xla/kernels/transpose_op.cc
"Transpose", "InvertPermutation",
// tf2xla/kernels/unpack_op.cc
"Unpack"});
return elementwise_ops->count(node.op()) > 0;
}
Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
const grappler::GrapplerItem& item,
GraphDef* output) {
VLOG(2) << "Here at fusion optimizer";
// TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op.
// Once that happens, the expected interaction between this optimizer and when
// the global_jit_level is set is as follows: Fusion optimizer will replace
// appropriate fusion clusters with XlaLaunch nodes. The remaining graph can
// be further compiled where possible via mark_for_compilation_pass. Note that
// this might lead to inefficient clustering, and it is best to use either the
// fusion optimizer or the global_jit flag, and not combine the two.
// Create a Graph out of GraphDef. This is required currently because the
// helpers around clustering, encapsulation etc work on graphs.
FunctionLibraryDefinition function_library(OpRegistry::Global(),
item.graph.library());
Graph graph(function_library);
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
shape_refiner.set_require_shape_inference_fns(false);
shape_refiner.set_disable_constant_propagation(true);
ImportGraphDefOptions options;
// Graph optimization happens at the late stage of graph execution, when
// colocation constraints are already validated previously and the device
// placement of nodes has also completed, so there is no need to validate
// colocation constraints again.
options.validate_colocation_constraints = false;
options.validate_shape = false;
TF_RETURN_IF_ERROR(
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
// Collect nodes that can be fused via XLA, while ignoring those that
// explicitly ask for XLA: (*) nodes that are marked to be compiled
// explicitly. (*) nodes assigned to XLA device.
OrderedNodeSet compilation_candidates;
for (Node* node : graph.op_nodes()) {
// If there is a _XlaCompile annotation, ignore the node if it is
// true. Nodes are marked with this attr via experimental_jit_scope, and
// will be handled by the mark_for_compilation pass.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
if (status.ok() && compile) {
continue;
}
// If there is already a _XlaCluster annotation, ignore the node. Nodes are
// marked with this attr to indicate they are already part of a cluster and
// hence ignored.
status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile);
if (status.ok()) {
continue;
}
// If there is an explicit XLA device placement, ignore the node.
DeviceType device_type("");
TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
if (device_type.type_string().find("XLA") != string::npos) continue;
// Assume all fusible ops are registered.
// TODO(hpucha): Check for registration if possible.
if (!IsXlaFusible(node->def())) {
continue;
}
// XLA does not offer guaranteed aliasing between the input and output of
// the XLA cluster so it can't implement the forward-tensor-ref semantic.
// Leave such nodes out of XLA clusters.
if (HasForwardedRefInput(*node)) {
continue;
}
compilation_candidates.insert(node);
}
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
*output = item.graph;
return Status::OK();
}
GraphCycles cycles;
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
CreateCycleDetectionGraph(&graph, &cycles));
if (!cycle_detection_graph_ok) {
return Status::OK();
}
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
&graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));
// TODO(hpucha): Make clustering more robust. There are two known issues that
// we need to mitigate: (a) Non-resource variables can cause deadlocks
// when clustering changes order of execution. See b/77263461 for a specific
// example. (b) Queue operations can also cause deadlocks. See b/77261498 for
// example.
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.
int representative = -1;
};
// Each compilation candidate belongs to a cluster. The cluster's
// representative names the node in the 'cycles' graph that represents the
// cluster.
std::vector<UnionFind<Cluster>> clusters(graph.num_node_ids());
std::deque<UnionFind<Cluster>*> worklist;
for (Node* node : compilation_candidates) {
Cluster& cluster = clusters[node->id()].Get();
cluster.representative = node->id();
worklist.push_back(&clusters[node->id()]);
}
std::unique_ptr<DeadnessAnalysis> deadness_analysis;
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness_analysis));
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle. This is a simplified
// version of the clustering in mark_for_compilation_pass that also deals with
// nodes that are explicitly tagged to be compiled/clustered.
while (!worklist.empty()) {
int from = worklist.front()->Get().representative;
worklist.pop_front();
Node* node_from = graph.FindNodeId(from);
if (node_from->IsControlFlow()) {
// Control flow nodes aren't compilation candidates and should never
// appear.
return errors::Internal(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
for (int to : cycles.Successors(from)) {
if (to >= graph.num_node_ids()) {
// Node is a "frame" node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
Node* node_to = graph.FindNodeId(to);
if (compilation_candidates.find(node_to) ==
compilation_candidates.cend()) {
continue;
}
// Do not cluster across devices.
if (node_from->def().device() != node_to->def().device()) {
VLOG(2) << "Devices " << node_from->def().device() << " "
<< node_to->def().device();
VLOG(2) << "Device names " << node_from->assigned_device_name() << " "
<< node_to->assigned_device_name();
continue;
}
// Ops that consume shapes cannot be the root of a cluster. This is an
// optimization.
if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
continue;
}
TF_ASSIGN_OR_RETURN(
DeadnessAnalysis::DeadnessPredicate pred_from,
deadness_analysis->GetPredicateFor(node_from, Graph::kControlSlot));
TF_ASSIGN_OR_RETURN(
DeadnessAnalysis::DeadnessPredicate pred_to,
deadness_analysis->GetPredicateFor(node_to, Graph::kControlSlot));
if (pred_from != pred_to) {
continue;
}
// If contracting the edge would create a cycle, bail out.
// However, just because we can't merge the clusters now does not mean
// we won't be able to merge them in the future.
// e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
// 1->3. But if we first contract 1->2 then we can later contract 1->3.
if (!cycles.ContractEdge(from, to)) continue;
// Merge the clusters. ContractEdge uses 'from' as the number of the
// merged node, so make sure 'from' is the chosen representative.
clusters[from].Merge(&clusters[to]);
worklist.push_back(&clusters[from]);
break;
}
}
// Count the number of non-trivial elements in each cluster.
std::vector<int> effective_cluster_sizes(graph.num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
// Identity nodes will be removed if the node gets marked for compilation.
// Therefore we don't want to count them towards the effective cluster size.
if (n->def().op() != "Identity") {
effective_cluster_sizes[cluster]++;
}
}
const int min_cluster_size = 2;
int num_clusters = 0;
for (auto size : effective_cluster_sizes) {
if (size >= min_cluster_size) {
VLOG(3) << "Cluster " << num_clusters << " " << size;
num_clusters++;
}
}
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
// Sequence number generator to ensure clusters have unique names.
static std::atomic<int64> cluster_sequence_num;
for (Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
// Compile if this is a cluster of >= min_cluster_size compilable operators.
if (effective_cluster_sizes[cluster] >= min_cluster_size) {
string& name = cluster_names[cluster];
if (name.empty()) {
name = absl::StrCat("cluster_", cluster_sequence_num++);
}
n->AddAttr(kXlaClusterAttr, name);
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
}
}
graph.ToGraphDef(output);
return Status::OK();
}
REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion");
} // namespace tensorflow


@@ -1,49 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
// Optimizes graphs by fusing ops where possible, resulting in more efficient
// execution.
class XlaFusionOptimizer : public grappler::CustomGraphOptimizer {
public:
XlaFusionOptimizer() {}
~XlaFusionOptimizer() override {}
Status Init(
const RewriterConfig_CustomGraphOptimizer* config = nullptr) override {
return Status::OK();
}
string name() const override { return "xla-fusion"; };
Status Optimize(grappler::Cluster* cluster,
const grappler::GrapplerItem& item,
GraphDef* output) override;
void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item,
const GraphDef& optimize_output, double result) override {
// Nothing to do for XlaFusionOptimizer.
}
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_


@@ -1,208 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
class XlaFusionOptimizerTest : public grappler::GrapplerTest {
protected:
std::unordered_map<string, string> GetClusters(const GraphDef& graph) {
std::unordered_map<string, string> ids;
for (const NodeDef& node : graph.node()) {
string cluster;
if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) {
CHECK(!cluster.empty());
ids[node.name()] = cluster;
}
}
return ids;
}
};
TEST_F(XlaFusionOptimizerTest, Chains) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
Node* d =
ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_EQ(clusters["E"], clusters["F"]);
EXPECT_NE(clusters["B"], clusters["E"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST_F(XlaFusionOptimizerTest, FusibleOps) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
"Placeholder",
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* b = ops::SourceOp(
"Placeholder",
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C"));
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(2, clusters.size());
EXPECT_EQ(clusters["C"], clusters["E"]);
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
"Placeholder",
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* b = ops::SourceOp(
"Placeholder",
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* c = ops::BinaryOp(
"Add", a, b,
builder.opts().WithName("C").WithDevice("/device:XLA_CPU"));
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
ops::UnaryOp("Cos", e,
builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_TRUE(clusters.empty());
}
TEST_F(XlaFusionOptimizerTest, UncompilableCycles) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b =
ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_TRUE(clusters.empty());
}
TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(3, clusters.size());
EXPECT_EQ(clusters["A"], clusters["B"]);
EXPECT_EQ(clusters["A"], clusters["C"]);
}
TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) {
Scope root = Scope::NewRootScope().ExitOnError();
Output var_handle =
ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({}));
Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f);
Output begin = ops::Const(root.WithOpName("begin"), 0);
Output end = ops::Const(root.WithOpName("end"), 1);
Output strides = ops::Const(root.WithOpName("strides"), 1);
ops::ResourceStridedSliceAssign assign_1(
root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign);
ops::ResourceStridedSliceAssign assign_2(
root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign);
root.graph()->AddControlEdge(assign_1.operation.node(),
assign_2.operation.node());
grappler::GrapplerItem item;
root.graph()->ToGraphDef(&item.graph);
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_NE(clusters["assign_1"], clusters["assign_2"]);
}
} // namespace
} // namespace tensorflow

View File

@ -23,6 +23,7 @@ import itertools
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
@ -1041,6 +1042,62 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([2], dtype=np.int64),
expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
def testBatchMatMulBroadcast(self):
"""Tests broadcasting behavior of BatchMatMul."""
with compat.forward_compatibility_horizon(2019, 4, 19):
# [2, 3] @ [1, 3, 4] -> [1, 2, 4]
self._testBinary(
math_ops.matmul,
np.array([[10, 20, 30], [11, 21, 31]], dtype=np.float32),
np.array([[[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]],
dtype=np.float32),
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
dtype=np.float32))
# [1, 2, 3] @ [3, 4] -> [1, 2, 4]
self._testBinary(
math_ops.matmul,
np.array([[[10, 20, 30], [11, 21, 31]]], dtype=np.float32),
np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]],
dtype=np.float32),
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
dtype=np.float32))
# [2, 1, 3] @ [3, 1] -> [2, 1, 1]
self._testBinary(
math_ops.matmul,
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
np.array([[1], [2], [3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [2, 1, 3] @ [1, 3] -> [2, 1, 1] (adjoint_b)
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_b=True),
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
np.array([[1, 2, 3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [2, 3, 1] @ [3, 1] -> [2, 1, 1] (adjoint_a)
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_a=True),
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
np.array([[1], [2], [3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [2, 3, 1] @ [1, 3] -> [2, 1, 1] (adjoint_a and adjoint_b)
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_a=True, adjoint_b=True),
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
np.array([[1, 2, 3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [5, 1, 2, 3] @ [1, 7, 3, 4] -> [5, 7, 2, 4]
self._testBinary(
math_ops.matmul,
np.ones([5, 1, 2, 3], dtype=np.float32),
np.ones([1, 7, 3, 4], dtype=np.float32),
expected=np.full([5, 7, 2, 4], 3, dtype=np.float32))
# [4, 5, 1, 2, 3] @ [1, 1, 3, 5] -> [4, 5, 1, 2, 5]
self._testBinary(
math_ops.matmul,
np.full([4, 5, 1, 2, 3], 2., dtype=np.float32),
np.full([1, 1, 3, 5], 3., dtype=np.float32),
expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
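As a quick cross-check of the broadcasting rule the first case above exercises, here is a minimal NumPy sketch; NumPy's matmul applies the same batch-broadcast semantics to the leading dimensions.
import numpy as np

a = np.array([[10, 20, 30], [11, 21, 31]], dtype=np.float32)                    # [2, 3]
b = np.array([[[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]], dtype=np.float32)   # [1, 3, 4]
c = np.matmul(a, b)   # a is broadcast against b's batch dimension
print(c.shape)        # (1, 2, 4)
print(c)              # [[[140. 280. 420. 560.]  [146. 292. 438. 584.]]]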
def testPad(self):
for dtype, pad_type in itertools.product(
self.numeric_types, [np.int32, np.int64]):

View File

@ -387,11 +387,18 @@ class TensorArrayTest(xla_test.XLATestCase):
def fn():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
return ta.write(-1, np.int32(7)).flow
return ta.write(-1, constant_op.constant(7)).flow
# Test writing the wrong datatype.
with self.assertRaisesOpError(
"TensorArray dtype is float but op has dtype int32"):
# TODO(b/129870929): Remove InvalidArgumentError/second regexp after all
# callers provide proper init dtype.
with self.assertRaisesRegexp(
(ValueError, errors.InvalidArgumentError),
r"("
r"conversion requested dtype float32 for Tensor with dtype int32"
r"|"
r"TensorArray dtype is float but op has dtype int32"
r")"):
xla.compile(fn)[0].eval()
@test_util.disable_control_flow_v2("b/124334096 verify dtype")

View File

@ -173,6 +173,7 @@ tf_cuda_library(
name = "trt_resources",
srcs = [
"utils/trt_int8_calibrator.cc",
"utils/trt_lru_cache.cc",
"utils/trt_resources.cc",
],
hdrs = [

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <map>
#include <memory>
@ -1350,11 +1351,19 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input,
// the dims are unknown or need to be inferred. And we don't do further checks
// but rely on the caller to not make mistakes.
// Otherwise we do a simple check to make sure the total sizes are the same.
if (AreDimsStaticWithDifferentSize(input_dims, dims, input.is_tensor())) {
// If an input is a weight, it is going to become a tensor via
// CreateConstantLayer. So we can treat it as a tensor for
// AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors.
if (AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
return errors::InvalidArgument(
"Incompatible shapes: ", DebugString(input_dims), " vs. ",
DebugString(dims));
}
// ConstantLayer requires static shapes (cannot infer -1).
if (input.is_weights() && !HasStaticShape(dims)) {
return errors::InvalidArgument("Shape is not fully defined: ",
DebugString(dims));
}
if (validation_only) {
*tensor = nullptr;
return Status::OK();
@ -1614,7 +1623,7 @@ struct LambdaFactory {
switch (op) {
case OP_CATEGORY::RSQRT: {
VLOG(2) << "RSQRT GETS DONE";
return [](T t) -> T { return 1.0 / sqrt(t); };
return [](T t) -> T { return 1.0 / std::sqrt(t); };
}
case OP_CATEGORY::NEG:
return [](T t) -> T { return -t; };
@ -1633,7 +1642,7 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
case OP_CATEGORY::RSQRT: {
VLOG(2) << "RSQRT GETS DONE";
return [](Eigen::half t) {
return Eigen::half(1.0 / sqrt(static_cast<float>(t)));
return Eigen::half(1.0 / std::sqrt(static_cast<float>(t)));
};
}
case OP_CATEGORY::NEG:
@ -4237,6 +4246,95 @@ Status ConvertTopK(OpConverterParams* params) {
return Status::OK();
}
Status ConvertDepthSpaceShuffle(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
TFAttrs attrs(node_def);
const int block_size = attrs.get<int64>("block_size");
if (block_size < 2) {
return errors::InvalidArgument("Block size must be 2 or greater, at ",
node_def.name());
}
const string data_format = attrs.get<string>("data_format");
if (data_format != "NCHW" && data_format != "NHWC") {
return errors::Unimplemented("Data format ", data_format,
" is not supported, at ", node_def.name());
}
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
if (dims.nbDims != 3) {
return errors::InvalidArgument("The input to ", node_def.op(),
" must be rank 4, at ", node_def.name());
}
const int num_channels = data_format == "NCHW" ? dims.d[0] : dims.d[2];
const int h = data_format == "NCHW" ? dims.d[1] : dims.d[0];
const int w = data_format == "NCHW" ? dims.d[2] : dims.d[1];
// Get shuffle parameters.
nvinfer1::Dims first_shuffle_shape;
nvinfer1::Permutation transpose_perm;
nvinfer1::Dims second_shuffle_shape;
if (node_def.op() == "DepthToSpace") {
if (num_channels % (block_size * block_size) != 0) {
return errors::InvalidArgument(
"Number of channels must be divisible by block_size*block_size, at ",
node_def.name());
}
// First Reshape [C, H, W] - > [r, r, C/(r*r), H, W]
first_shuffle_shape = {
/*nbDims=*/5,
/*d=*/{block_size, block_size, num_channels / (block_size * block_size),
h, w}};
// Transpose [r, r, C/(r*r), H, W] -> [C/(r*r), H, r, W, r]
transpose_perm = {2, 3, 0, 4, 1};
// Second Reshape [C/(r*r), H, r, W, r] -> [C/(r*r), H * r, W * r]
second_shuffle_shape =
nvinfer1::DimsCHW(num_channels / (block_size * block_size),
h * block_size, w * block_size);
} else if (node_def.op() == "SpaceToDepth") {
if (h % block_size != 0 || w % block_size != 0) {
return errors::InvalidArgument(
"Width and height must be divisible by block_size, at ",
node_def.name());
}
// First Reshape [C, H, W] -> [C, H/r, r, W/r, r]
first_shuffle_shape = {/*nbDims=*/5,
/*d=*/{num_channels, h / block_size, block_size,
w / block_size, block_size}};
// Transpose [C, H/r, r, W/r, r] -> [r, r, C, H/r, W/r]
transpose_perm = {2, 4, 0, 1, 3};
// Second Reshape [r, r, C, H/r, W/r] -> [C*r*r, H/r, W/r]
second_shuffle_shape = nvinfer1::DimsCHW(
num_channels * block_size * block_size, h / block_size, w / block_size);
}
if (params->validation_only) return Status::OK();
nvinfer1::IShuffleLayer* first_shuffle =
params->converter->network()->addShuffle(*inputs.at(0).tensor());
TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name());
if (data_format == "NHWC") {
first_shuffle->setFirstTranspose({2, 0, 1});
}
first_shuffle->setReshapeDimensions(first_shuffle_shape);
first_shuffle->setSecondTranspose(transpose_perm);
nvinfer1::IShuffleLayer* second_shuffle =
params->converter->network()->addShuffle(*first_shuffle->getOutput(0));
TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name());
second_shuffle->setReshapeDimensions(second_shuffle_shape);
if (data_format == "NHWC") {
second_shuffle->setSecondTranspose({1, 2, 0});
}
params->converter->MarkQuantizationRangesAsInferrable(
inputs.at(0).tensor(), first_shuffle->getOutput(0));
params->converter->MarkQuantizationRangesAsInferrable(
first_shuffle->getOutput(0), second_shuffle->getOutput(0));
params->outputs->push_back(TRT_TensorOrWeights(second_shuffle->getOutput(0)));
return Status::OK();
}
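The reshape/transpose/reshape decomposition above can be checked with a small NumPy sketch of the NCHW DepthToSpace case (batch dimension omitted, as in TensorRT); it reproduces the {4, 2, 2}, block_size=2 expectation used in the converter tests later in this diff.
import numpy as np

c, h, w, r = 4, 2, 2, 2
x = np.arange(c * h * w).reshape(c, h, w)   # CHW input holding 0..15
y = x.reshape(r, r, c // (r * r), h, w)     # first reshape:  [C, H, W] -> [r, r, C/(r*r), H, W]
y = y.transpose(2, 3, 0, 4, 1)              # transpose:      -> [C/(r*r), H, r, W, r]
y = y.reshape(c // (r * r), h * r, w * r)   # second reshape: -> [C/(r*r), H*r, W*r]
print(y.ravel())  # [ 0  4  1  5  8 12  9 13  2  6  3  7 10 14 11 15]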
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
Status ConvertCombinedNMS(OpConverterParams* params) {
TF_RETURN_IF_ERROR(
@ -4416,6 +4514,7 @@ static void RegisterValidatableOpConverters(
(*registration)["Const"] = ConvertConst;
(*registration)["Conv2D"] = ConvertConv2D;
(*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
(*registration)["DepthToSpace"] = ConvertDepthSpaceShuffle;
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
(*registration)["ExpandDims"] = ConvertExpandDims;
(*registration)["GatherV2"] = ConvertGather;
@ -4430,6 +4529,7 @@ static void RegisterValidatableOpConverters(
(*registration)["Slice"] = ConvertSlice;
(*registration)["Snapshot"] = ConvertIdentity; // Snapshot should be removed
(*registration)["Softmax"] = ConvertSoftmax;
(*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle;
(*registration)["Split"] = ConvertSplit;
(*registration)["Square"] = ConvertSquare;
(*registration)["Squeeze"] = ConvertSqueeze;

View File

@ -212,6 +212,19 @@ std::vector<CType> InitTestVector(int size, CType start_value = CType(0)) {
return res;
}
template <typename InCType, typename OutCType>
struct StaticCaster {
OutCType operator()(InCType in) const { return static_cast<OutCType>(in); }
};
template <typename InCType, typename OutCType>
std::vector<OutCType> CastTestVector(const std::vector<InCType>& vals) {
std::vector<OutCType> res(vals.size());
std::transform(vals.begin(), vals.end(), res.begin(),
StaticCaster<InCType, OutCType>());
return res;
}
// Fake ITensor implementation for testing purposes.
class FakeITensor : public nvinfer1::ITensor {
public:
@ -721,19 +734,25 @@ TEST_F(ConverterTest, TransposeTensor) {
ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions());
}
void TestPrepareTensorForShape_Tensor(
const std::vector<int>& tensor_dims, const std::vector<int>& reshape_dims,
const std::vector<int>& expected_tensor_dims, Converter* converter,
void TestPrepareTensorForShape(
const std::vector<int>& input_dims, const std::vector<int>& reshape_dims,
const std::vector<int>& expected_tensor_dims, bool input_is_tensor,
Converter* converter, TrtWeightStore* weight_store,
error::Code expected_code = error::OK,
const char* expected_error_msg_substr = nullptr) {
nvinfer1::ITensor* input_tensor = converter->network()->addInput(
"", nvinfer1::DataType::kFLOAT, GetTestDims(tensor_dims));
TRT_TensorOrWeights input;
if (input_is_tensor) {
input = TRT_TensorOrWeights(converter->network()->addInput(
"", nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
} else {
input = TRT_TensorOrWeights(weight_store->GetTempWeights(
nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
}
nvinfer1::ITensor* output_tensor = nullptr;
for (bool validation_only : {false, true}) {
const Status status = converter->PrepareTensorForShape(
TRT_TensorOrWeights(input_tensor), GetTestDims(reshape_dims),
validation_only, &output_tensor);
input, GetTestDims(reshape_dims), validation_only, &output_tensor);
if (expected_code == error::OK) {
TF_EXPECT_OK(status);
if (validation_only) {
@ -748,49 +767,45 @@ void TestPrepareTensorForShape_Tensor(
}
}
TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
// Shape size doesn't match.
TEST_F(ConverterTest, PrepareTensorForShape) {
for (bool input_is_tensor : {true, false}) {
// Shape size doesn't match.
Reset();
TestPrepareTensorForShape({2, 3, 5}, {2, 3, 6}, {}, input_is_tensor,
converter_.get(), weight_store_,
error::INVALID_ARGUMENT, "Incompatible shapes");
// Regular shape.
Reset();
TestPrepareTensorForShape({2, 3, 5}, {10, 3}, {10, 3}, input_is_tensor,
converter_.get(), weight_store_);
// Reshape to zero rank.
Reset();
TestPrepareTensorForShape({1, 1}, {}, {}, input_is_tensor, converter_.get(),
weight_store_);
}
// Tensor input with zero rank.
Reset();
TestPrepareTensorForShape_Tensor({2, 3, 5}, {2, 3, 6}, {}, converter_.get(),
error::INVALID_ARGUMENT,
"Incompatible shapes");
TestPrepareTensorForShape({}, {1, 1}, {1, 1}, /*input_is_tensor=*/true,
converter_.get(), weight_store_);
// TODO(aaroey): we should check the case where uninferred dimensions are
// not an exact divisor of input dimensions, e.g. for dims {-1, 7}.
// Infer shape, ok.
// Infer tensor shape, ok.
Reset();
TestPrepareTensorForShape_Tensor({2, 3, 5}, {-1, 2}, {15, 2},
converter_.get());
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
/*input_is_tensor=*/true, converter_.get(),
weight_store_);
// Regular shape.
// Infer weight shape, should fail.
Reset();
TestPrepareTensorForShape_Tensor({2, 3, 5}, {10, 3}, {10, 3},
converter_.get());
// Input with zero rank.
Reset();
TestPrepareTensorForShape_Tensor({}, {1, 1}, {1, 1}, converter_.get());
// Reshape to zero rank.
Reset();
TestPrepareTensorForShape_Tensor({1, 1}, {}, {}, converter_.get());
}
TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
TRT_ShapedWeights weights = weight_store_->GetTempWeights(
nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
nvinfer1::ITensor* output_tensor = nullptr;
for (bool validation_only : {false, true}) {
TF_EXPECT_OK(converter_->PrepareTensorForShape(
TRT_TensorOrWeights(weights), GetTestDims({10, 3}), validation_only,
&output_tensor));
if (validation_only) {
EXPECT_EQ(nullptr, output_tensor);
} else {
ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
}
}
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
/*input_is_tensor=*/false, converter_.get(),
weight_store_, error::INVALID_ARGUMENT,
"Shape is not fully defined");
}
TEST_F(ConverterTest, MaybeUpdateBatchSize) {
@ -4910,6 +4925,279 @@ TEST_F(OpConverterTest, ConvertArgMinMax) {
// TestConvertArgMinMax<ops::ArgMax, DT_INT32>(this);
}
// Get the NodeDef for DepthToSpace or SpaceToDepth.
template <typename OpType>
NodeDef GetDepthSpaceShuffleNodeDef(DataType dtype, int block_size,
string data_format) {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), dtype);
auto attrs = OpType::DataFormat(data_format);
auto shuffle = OpType(s.WithOpName("my_shuffle"), input, block_size, attrs);
return shuffle.operation.node()->def();
}
template <typename CType>
struct DepthSpaceShuffleTestParams {
std::vector<int> input_dims;
std::vector<CType> input_value;
int block_size;
string data_format;
std::vector<int> expected_output_dims;
std::vector<CType> expected_output;
};
template <typename OpType, DataType dtype, typename CType>
void TestConvertDepthSpaceShuffle(
OpConverterTest* test,
const std::vector<DepthSpaceShuffleTestParams<CType>>& params) {
for (int i = 0; i < params.size(); ++i) {
test->Reset();
NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>(
dtype, params[i].block_size, params[i].data_format);
test->AddTestTensor("input", params[i].input_dims, 1,
TfDataTypeToTrt(dtype));
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_shuffle", &output));
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
output.tensor()->getDimensions());
DataVec input_data{{"input", test::AsTensor<CType>(params[i].input_value)}};
DataVec output_data{{"my_shuffle", ConstructTensor<CType>(
params[i].expected_output.size())}};
test->BuildAndRun(
input_data, &output_data,
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(params[i].expected_output));
}
}
template <DataType dtype>
void TestConvertDepthToSpace(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const std::vector<CType> common_input = InitTestVector<CType>(16);
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
{
/*input_shape=*/{4, 2, 2},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{1, 4, 4},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15}),
},
{
/*input_shape=*/{2, 2, 4},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{4, 4, 1},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
},
{
/*input_shape=*/{16, 1, 1},
/*input_value=*/common_input,
/*block_size=*/4,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{1, 4, 4},
/*expected_output=*/InitTestVector<CType>(16),
},
{
/*input_shape=*/{2, 2, 8},
/*input_value=*/InitTestVector<CType>(32),
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{4, 4, 2},
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
9, 10, 11, 4, 5,
6, 7, 12, 13, 14,
15, 16, 17, 18, 19,
24, 25, 26, 27, 20,
21, 22, 23, 28, 29,
30, 31}),
},
};
TestConvertDepthSpaceShuffle<ops::DepthToSpace, dtype, CType>(test, params);
}
TEST_F(OpConverterTest, ConvertDepthToSpace) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_shuffle", "DepthToSpace", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"DepthToSpace got 0 inputs but expected 1, at my_shuffle");
}
{
// Input is a weight, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"The input \"input\" for DepthToSpace must be a "
"tensor, at my_shuffle");
}
{
// Input rank != 4
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
AddTestTensor("input", {16, 32});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"The input to DepthToSpace must be rank 4, at my_shuffle");
}
{
// Channels not divisible by block_size, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 3, "NCHW");
AddTestTensor("input", {16, 32, 32});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Number of channels must be divisible by "
"block_size*block_size, at my_shuffle");
}
{
// Unsupported format, should fail.
Reset();
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(
DT_FLOAT, 2, "NCHW_VECT_C");
AddTestTensor("input", {16, 32, 32});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Data format NCHW_VECT_C is not supported, at my_shuffle");
}
TestConvertDepthToSpace<DT_FLOAT>(this);
TestConvertDepthToSpace<DT_HALF>(this);
TestConvertDepthToSpace<DT_INT32>(this);
}
template <DataType dtype>
void TestConvertSpaceToDepth(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const std::vector<CType> common_input = InitTestVector<CType>(16);
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
{
/*input_shape=*/{1, 4, 4},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{4, 2, 2},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15}),
},
{
/*input_shape=*/{4, 4, 1},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{2, 2, 4},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
},
{
/*input_shape=*/{1, 4, 4},
/*input_value=*/common_input,
/*block_size=*/4,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{16, 1, 1},
/*expected_output=*/InitTestVector<CType>(16),
},
{
/*input_shape=*/{4, 4, 2},
/*input_value=*/InitTestVector<CType>(32),
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{2, 2, 8},
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
9, 10, 11, 4, 5,
6, 7, 12, 13, 14,
15, 16, 17, 18, 19,
24, 25, 26, 27, 20,
21, 22, 23, 28, 29,
30, 31}),
},
};
TestConvertDepthSpaceShuffle<ops::SpaceToDepth, dtype, CType>(test, params);
}
TEST_F(OpConverterTest, ConvertSpaceToDepth) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_shuffle", "SpaceToDepth", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"SpaceToDepth got 0 inputs but expected 1, at my_shuffle");
}
{
// Input is a weight, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"The input \"input\" for SpaceToDepth must be a "
"tensor, at my_shuffle");
}
{
// Input rank != 4
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
AddTestTensor("input", {16, 32});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"The input to SpaceToDepth must be rank 4, at my_shuffle");
}
{
// Width not divisible by block_size, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
AddTestTensor("input", {16, 9, 32});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Width and height must be divisible by "
"block_size, at my_shuffle");
}
{
// Height not divisible by block_size, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
AddTestTensor("input", {16, 32, 9});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Width and height must be divisible by "
"block_size, at my_shuffle");
}
{
// Unsupported format, should fail.
Reset();
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(
DT_FLOAT, 2, "NCHW_VECT_C");
AddTestTensor("input", {16, 32, 32});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Data format NCHW_VECT_C is not supported, at my_shuffle");
}
TestConvertSpaceToDepth<DT_FLOAT>(this);
TestConvertSpaceToDepth<DT_HALF>(this);
TestConvertSpaceToDepth<DT_INT32>(this);
}
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
@ -290,17 +291,17 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
VLOG(1) << "Executing TRT calibration: " << name();
helper->Ref();
core::ScopedUnref sc(helper);
auto res_mgr = ctx->resource_manager();
TRTCalibrationResource* calib_res = nullptr;
OP_REQUIRES_OK(ctx,
res_mgr->LookupOrCreate(
"TF_TRT_Calibration", name(),
ctx->resource_manager()->LookupOrCreate(
"TF-TRT-Calibration", name(),
reinterpret_cast<SerializableResourceBase**>(&calib_res),
{[ctx, this](SerializableResourceBase** cr) -> Status {
return this->AllocateCalibrationResources(ctx, cr);
}}));
core::ScopedUnref calib_sc(calib_res);
int num_inputs = ctx->num_inputs();
// TODO(laigd): need to check that input shape matches.
// Pass input data to calibrator
std::unordered_map<string, void*> input_data;
for (int i = 0; i < num_inputs; i++) {
@ -522,10 +523,22 @@ EngineContext* TRTEngineOp::GetEngine(
// TODO(tmorris): using first input to get batch size - is this reliable?
const int batch_size = input_shapes[0].dim_size(0);
// Get engine cache
// Canonicalize the op name by removing the scopes, if any. This is mainly
// because in TFv2 the function graph can be instantiated in various ways, and
// scope names get inserted into the names of the TRTEngineOps. Using the
// instantiated op name directly would therefore produce many different engine
// caches, while we still want all instantiations of the same subgraph to
// share a single cache.
absl::string_view resource_name = name();
size_t last_slash = resource_name.find_last_of('/');
if (last_slash != absl::string_view::npos) {
resource_name.remove_prefix(last_slash + 1);
}
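In Python terms, the canonicalization above simply keeps the part of the op name after the last '/'; the instantiated name below is made up for illustration.
name = "StatefulPartitionedCall/while/body/TRTEngineOp_0"  # hypothetical instantiated name
resource_name = name.rsplit("/", 1)[-1]
print(resource_name)  # TRTEngineOp_0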
// Get engine cache.
TRTEngineCacheResource* cache_res = nullptr;
auto status = ctx->resource_manager()->LookupOrCreate(
"TRTEngineCache", name(), &cache_res,
"TF-TRT-Engine-Cache", string(resource_name), &cache_res,
{[this, ctx](TRTEngineCacheResource** cr) -> Status {
*cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
return Status::OK();
@ -632,12 +645,13 @@ EngineContext* TRTEngineOp::GetEngine(
cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
return &empty_context;
}
VLOG(1) << "Conversion is done";
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
engine->createExecutionContext());
cache.emplace(engine_input_shapes,
absl::make_unique<EngineContext>(std::move(engine),
std::move(exec_context)));
VLOG(1) << "Added new engine to cache of " << name()
<< ". Cache size: " << cache.size();
}
return cache.at(engine_input_shapes).get();
}

View File

@ -459,7 +459,7 @@ Status SegmentGraph(const Graph* tf_graph,
}
LOG(INFO) << msg << "(For more information see "
<< "https://docs.nvidia.com/deeplearning"
<< "/dgx/integrate-tf-trt/index.html#support-ops).";
<< "/dgx/tf-trt-user-guide/index.html#supported-ops).";
// The segmentation algorithm below visits nodes in reverse topological order
// and attempts to merge nodes along output edges. That means that subgraphs

View File

@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include <sstream>
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/mutex.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx,
size_t capacity)
: cache_(capacity) {
auto device = ctx->device();
auto alloc = device->GetAllocator(AllocatorAttributes());
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
allocator_ = nullptr;
} else {
allocator_.reset(new TRTDeviceAllocator(alloc));
}
}
TRTEngineCacheResource::~TRTEngineCacheResource() {
VLOG(1) << "Destroying TRTEngineCacheResource...";
}
string TRTEngineCacheResource::DebugString() const {
std::stringstream oss;
using std::dec;
using std::endl;
using std::hex;
oss << "TRTEngineCacheResource: ";
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
oss << "LRUCache = " << hex << &cache_ << dec << endl;
oss << "Containing " << cache_.size() << " entries: " << endl;
for (const auto& item : cache_) {
mutex_lock lock(item.second->mu);
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
<< "ICudaEngine: " << item.second->cuda_engine.get() << ", "
<< "IExecutionContext: " << item.second->execution_context.get() << dec
<< endl;
}
return oss.str();
}
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -141,36 +141,11 @@ struct EngineContext {
class TRTEngineCacheResource : public ResourceBase {
public:
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity)
: cache_(capacity) {
auto device = ctx->device();
auto alloc = device->GetAllocator(AllocatorAttributes());
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
allocator_ = nullptr;
} else {
allocator_.reset(new TRTDeviceAllocator(alloc));
}
}
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity);
string DebugString() const override {
std::stringstream oss;
using std::dec;
using std::endl;
using std::hex;
oss << "TRTEngineCacheResource: ";
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
oss << "LRUCache = " << hex << &cache_ << dec << endl;
oss << "Containing " << cache_.size() << " entries: " << endl;
for (const auto& item : cache_) {
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
<< "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", "
<< "IExecutionContext: " << item.second.get()->execution_context.get()
<< dec << endl;
}
return oss.str();
}
~TRTEngineCacheResource() override;
string DebugString() const override;
// Keep device allocator for TRT.
std::unique_ptr<TRTBaseAllocator> allocator_;

View File

@ -44,6 +44,7 @@ class BatchMatMulOp : public XlaOpKernel {
};
REGISTER_XLA_OP(Name("BatchMatMul"), BatchMatMulOp);
REGISTER_XLA_OP(Name("BatchMatMulV2"), BatchMatMulOp);
} // namespace
} // namespace tensorflow

View File

@ -15,6 +15,7 @@ limitations under the License.
// XLA implementations of Categorical op.
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@ -140,8 +141,6 @@ class StatelessCategoricalOp : public CategoricalOp {
xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type,
XlaOpKernelContext* ctx) override {
xla::XlaOp seed = ctx->Input(2);
auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
xla::XlaBuilder* builder = ctx->builder();
if (uniform_shape.element_type() == xla::BF16) {
@ -150,8 +149,8 @@ class StatelessCategoricalOp : public CategoricalOp {
// We want a number in (0, 1) rather than [0, 1) or (0, 1]:
// * log(-log(0)) is ∞.
// * log(-log(1)) is -∞.
auto uniforms = xla::StatelessRngUniform(
{seed0, seed1}, uniform_shape,
xla::XlaOp uniforms = StatelessRngUniform(
seed, uniform_shape,
xla::MinPositiveNormalValue(builder, uniform_shape.element_type()),
xla::One(builder, uniform_shape.element_type()));
return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type);
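A quick numerical check of the endpoint behavior noted in the comment above, in plain NumPy (illustration only):
import numpy as np

with np.errstate(divide="ignore"):
    print(np.log(-np.log(0.0)))   # inf
    print(np.log(-np.log(1.0)))   # -inf
print(np.log(-np.log(0.5)))       # ~ -0.3665, finite for u strictly inside (0, 1)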

View File

@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Ops for 2D convolution.
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@ -293,10 +294,9 @@ xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
return attrs;
}
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
xla::XlaOp conv_input,
xla::XlaOp filter,
const ConvOpAttrs& attrs) {
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter,
const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
auto* builder = conv_input.builder();
@ -377,12 +377,14 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
return xla::ConvGeneralDilated(
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
dims,
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count);
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count,
/*batch_group_count=*/1, precision_config);
}
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
int num_dims = attrs.num_spatial_dims + 2;
@ -456,13 +458,14 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
/*feature_group_count=*/
attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
filter_shape.dimensions(attrs.num_spatial_dims + 1)
: feature_group_count);
: feature_group_count,
/*batch_group_count=*/1, precision_config);
}
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
StringPiece type_string, xla::XlaOp activations,
const xla::Shape& filter_shape, xla::XlaOp gradients,
const ConvOpAttrs& attrs) {
const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
auto* builder = activations.builder();
@ -612,7 +615,8 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
rhs_dilation, dnums,
/*feature_group_count=*/feature_group_count,
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1);
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1,
precision_config);
if (!use_batch_group_count && attrs.depthwise) {
filter_backprop = ContractFilterForDepthwiseBackprop(

View File

@ -53,17 +53,19 @@ struct ConvOpAttrs {
// Creates a new XLA forward or backward convolution with the given inputs and
// attributes.
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
xla::XlaOp conv_input,
xla::XlaOp filter,
const ConvOpAttrs& attrs);
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
StringPiece type_string, xla::XlaOp conv_input, xla::XlaOp filter,
const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config = nullptr);
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config = nullptr);
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
StringPiece type_string, xla::XlaOp activations,
const xla::Shape& filter_shape, xla::XlaOp gradients,
const ConvOpAttrs& attrs);
const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config = nullptr);
} // namespace tensorflow

View File

@ -22,6 +22,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
// Returns a tensor of the given `shape` filled with random values uniformly
// distributed in the range [minval, maxval). The raw random bits are produced
// with the ThreeFry bit generator and converted to the requested data type and
// range. This routine requires two 32-bit integer seeds and currently only
// supports shapes with element type F32, S32 or S64.
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
xla::XlaOp minval, xla::XlaOp maxval);
// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.
// It masks the last 16 bits. With normal rounding, values near "maxval" would be

View File

@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel {
TensorShape shape;
int64 product = 1;
int unknown_index = -1;
bool shape_has_zero_dim = false;
for (int d = 0; d < num_dims; ++d) {
const int32 size = shape_input[d];
if (size == -1) {
@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel {
unknown_index, " and ", d));
unknown_index = d;
shape.AddDim(1);
} else if (size == 0) {
// We don't include zero-sized dimensions in the product, so that we can
// still compute the number of elements of the non-zero-sized dimensions
// and therefore infer the size of the unknown dimension.
shape.AddDim(size);
shape_has_zero_dim = true;
} else {
OP_REQUIRES(ctx, size >= 0,
errors::InvalidArgument(
@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel {
}
}
if (unknown_index != -1) {
OP_REQUIRES(
ctx, product > 0,
errors::InvalidArgument("Reshape cannot infer the missing input size "
"for an empty tensor unless all specified "
"input sizes are non-zero"));
const int64 missing = input_shape.num_elements() / product;
OP_REQUIRES(
ctx, product * missing == input_shape.num_elements(),
errors::InvalidArgument(
"Input to reshape is a tensor with ", input_shape.num_elements(),
" values, but the requested shape requires a multiple of ",
product));
int64 input_num_elements = 1;
bool input_has_zero_dim = false;
for (int dim = 0; dim < input_shape.dims(); dim++) {
// Zero-sized dimensions are not counted into `input_num_elements` unless
// `sizes` contains no zero-sized dimension, so that we can still infer
// the sizes of the other dimensions.
if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) {
input_num_elements *= input_shape.dim_size(dim);
} else {
input_has_zero_dim = true;
}
}
const int64 missing = input_num_elements / product;
if (!input_has_zero_dim) {
OP_REQUIRES(
ctx, product * missing == input_num_elements,
errors::InvalidArgument(
"Input to reshape is a tensor with ", input_num_elements,
" values, but the requested shape requires a multiple of ",
product));
}
shape.set_dim(unknown_index, missing);
}
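A hypothetical Python paraphrase of the shape-inference logic above (illustration only; infer_reshape is not a real TensorFlow helper):
def infer_reshape(input_shape, sizes):
  # Mirrors the C++ above: -1 marks the unknown dimension; zero-sized
  # dimensions are skipped when computing the product and the element count.
  product, unknown_index = 1, None
  shape_has_zero_dim = 0 in sizes
  out = list(sizes)
  for d, size in enumerate(sizes):
    if size == -1:
      unknown_index = d
    elif size > 0:
      product *= size
  if unknown_index is not None:
    input_num_elements = 1
    for dim in input_shape:
      if dim > 0 or not shape_has_zero_dim:
        input_num_elements *= dim
    out[unknown_index] = input_num_elements // product
  return out

print(infer_reshape([3, 0, 4], [-1, 0, 2]))  # [6, 0, 2]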
OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),

View File

@ -35,127 +35,50 @@ limitations under the License.
namespace tensorflow {
namespace {
std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter(
xla::XlaOp counter, const int64 size) {
auto builder = counter.builder();
auto input_u64 = Iota(builder, xla::U64, size);
input_u64 = input_u64 + counter;
counter = counter + xla::ConstantR0<uint64>(builder, size);
return std::make_pair(xla::Uint64ToUint32s(input_u64), counter);
}
// `StatelessRngUniformU32` uses ThreeFry2x32's counter space too
// wastefully: it can only generate 2^32*2 int32 numbers for each key, while
// the real capacity is 2^64*2. Counter-space efficiency is important for
// stateful ops, hence the following 2 new functions.
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32(
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
auto builder = key.builder();
const int64 size = xla::ShapeUtil::ElementsIn(shape);
const int64 half_size = xla::CeilOfRatio<int64>(size, 2);
const bool size_is_odd = (half_size * 2 != size);
auto inputs_counter = GetInputsFromCounter(counter, half_size);
auto inputs = inputs_counter.first;
counter = inputs_counter.second;
auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key));
if (size_is_odd) {
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
}
auto result = ConcatInDim(builder, outputs, 0);
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
counter);
}
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64(
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
const int64 size = xla::ShapeUtil::ElementsIn(shape);
auto inputs_counter = GetInputsFromCounter(counter, size);
auto inputs = inputs_counter.first;
counter = inputs_counter.second;
auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
auto result = Uint32sToUint64(outputs);
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
counter);
}
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
xla::XlaOp counter,
const xla::Shape& shape,
xla::XlaOp minval,
xla::XlaOp maxval) {
auto builder = key.builder();
xla::RngOutput StatefulRngUniform(xla::XlaOp key, xla::XlaOp initial_state,
const xla::Shape& shape, xla::XlaOp minval,
xla::XlaOp maxval) {
xla::PrimitiveType type = shape.element_type();
switch (type) {
case xla::F32: {
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
auto bits = bits_counter.first;
counter = bits_counter.second;
return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval),
counter);
}
case xla::U32: // fall through
case xla::S32: {
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
auto bits = bits_counter.first;
counter = bits_counter.second;
return std::make_pair(
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32),
counter);
}
case xla::U64: // fall through
case xla::S64: {
auto bits_counter = StatefulRngUniformU64(key, counter, shape);
auto bits = bits_counter.first;
counter = bits_counter.second;
return std::make_pair(
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64),
counter);
}
case xla::F32:
return xla::UniformF32Distribution(
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
case xla::U32:
case xla::S32:
case xla::U64:
case xla::S64:
return UniformIntDistribution(
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
default:
return std::make_pair(
builder->ReportError(xla::Unimplemented(
"Types other than F32, U32, S32, U64 and S64 "
"are not implemented by "
"StatefulRngUniform; got: %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
counter);
return {key.builder()->ReportError(xla::Unimplemented(
"Types other than F32, U32, S32, U64 and S64 "
"are not implemented by "
"StatefulRngUniform; got %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
initial_state};
}
}
template <typename A, typename B, typename A2>
std::pair<A2, B> map_first(std::function<A2(A)> f, std::pair<A, B> p) {
return std::make_pair(f(p.first), p.second);
}
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformFullInt(
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
xla::RngOutput StatefulRngUniformFullInt(xla::XlaOp key,
xla::XlaOp initial_state,
const xla::Shape& shape) {
xla::PrimitiveType type = shape.element_type();
xla::RngOutput output = xla::ThreeFryBitGenerator(key, initial_state, shape);
switch (type) {
case xla::U32:
return StatefulRngUniformU32(key, counter, shape);
case xla::S32: {
// Needs explicit function type because of type-inference failure.
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
return BitcastConvertType(x, xla::S32);
};
return map_first(f, StatefulRngUniformU32(key, counter, shape));
}
case xla::U64:
return StatefulRngUniformU64(key, counter, shape);
case xla::S64: {
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
return BitcastConvertType(x, xla::S64);
};
return map_first(f, StatefulRngUniformU64(key, counter, shape));
}
return output;
case xla::S32:
case xla::S64:
output.value = BitcastConvertType(output.value, type);
return output;
default:
auto builder = key.builder();
return std::make_pair(
builder->ReportError(xla::Unimplemented(
return {
key.builder()->ReportError(xla::Unimplemented(
"Types other than U32, S32, U64 and S64 are not implemented by "
"StatefulRngUniformFullInt; got: %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
counter);
initial_state};
}
}
@ -177,15 +100,15 @@ xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
0);
}
using sampler_return_type = xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>>;
using SamplerReturnType = xla::StatusOr<xla::RngOutput>;
// A helper function containing the common part of several kernels below.
// Precondition: 'algorithm' and 'shape' are compile-time constants.
Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
int alg_input_idx, int shape_input_idx,
std::function<sampler_return_type(xla::XlaOp, xla::XlaOp,
TensorShape)> const&
sample_with_threefry) {
Status CompileImpl(
XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
int shape_input_idx,
std::function<SamplerReturnType(xla::XlaOp, xla::XlaOp, TensorShape)> const&
sampler) {
auto alg_shape = ctx->InputShape(alg_input_idx);
if (alg_shape.dims() != 0) {
return errors::InvalidArgument("algorithm must be of shape [], not ",
@ -215,24 +138,22 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
TensorShape shape;
TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
static constexpr int COUNTER_SIZE = 1;
auto counter = BitcastConvertType(
xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64);
static constexpr int kStateSize = 1;
auto state = BitcastConvertType(
xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
auto key = BitcastConvertType(
xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}),
{}),
xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
xla::U64);
auto status_or_value = sample_with_threefry(counter, key, shape);
auto status_or_value = sampler(state, key, shape);
if (!status_or_value.ok()) {
return status_or_value.status();
}
auto output_counter = status_or_value.ConsumeValueOrDie();
auto output = output_counter.first;
counter = output_counter.second;
ctx->SetOutput(0, output);
auto builder = ctx->builder();
var = ConcatScalars(builder, {counter, key});
xla::RngOutput value_state = status_or_value.ConsumeValueOrDie();
state = value_state.state;
ctx->SetOutput(0, value_state.value);
xla::XlaBuilder* builder = ctx->builder();
var = ConcatScalars(builder, {state, key});
xla::PrimitiveType state_element_type;
TF_RETURN_IF_ERROR(
DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
@ -252,23 +173,22 @@ class StatefulUniformOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto builder = ctx->builder();
auto sample_with_threefry = [builder, this](
xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
xla::XlaBuilder* builder = ctx->builder();
auto sampler = [builder, this](xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
auto uniform_counter = StatefulRngUniform(
key, counter, xla_shape, xla::ConstantR0<float>(builder, 0.0),
xla::RngOutput uniform_state = StatefulRngUniform(
key, state, xla_shape, xla::ConstantR0<float>(builder, 0.0),
xla::ConstantR0<float>(builder, 1.0));
auto uniform = uniform_counter.first;
counter = uniform_counter.second;
xla::XlaOp uniform = uniform_state.value;
state = uniform_state.state;
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
return {{uniform, counter}};
return {{uniform, state}};
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
/*shape_input_idx=*/2, sample_with_threefry));
/*shape_input_idx=*/2, sampler));
}
private:
@ -293,30 +213,20 @@ class StatefulStandardNormalOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto builder = ctx->builder();
auto sample_with_threefry =
auto sampler =
// Needs explicit lambda return type because it fails to be inferred.
[builder, this](xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
[this](xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
auto uniform_counter = StatefulRngUniform(
key, counter, xla_shape,
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
xla::ConstantR0<float>(builder, 1.0));
auto uniform = uniform_counter.first;
counter = uniform_counter.second;
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
auto normal =
xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
normal = MaybeConvertF32ToBF16(normal, dtype_);
return {{normal, counter}};
xla::RngOutput value_state = xla::NormalF32Distribution(
key, state, xla::ThreeFryBitGenerator, xla_shape);
xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
return {{normal, value_state.state}};
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
/*shape_input_idx=*/2, sample_with_threefry));
/*shape_input_idx=*/2, sampler));
}
private:
@ -341,27 +251,27 @@ class StatefulTruncatedNormalOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto builder = ctx->builder();
auto sample_with_threefry =
xla::XlaBuilder* builder = ctx->builder();
auto sampler =
// Needs explicit lambda return type because it fails to be inferred.
[builder, this](xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
[builder, this](xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
auto uniform_counter = StatefulRngUniform(
key, counter, xla_shape,
xla::RngOutput uniform_result = StatefulRngUniform(
key, state, xla_shape,
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
xla::One(builder, xla_shape.element_type()));
auto uniform = uniform_counter.first;
counter = uniform_counter.second;
xla::XlaOp uniform = uniform_result.value;
state = uniform_result.state;
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
return {{truncated_normal, counter}};
return {{truncated_normal, state}};
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
/*shape_input_idx=*/2, sample_with_threefry));
/*shape_input_idx=*/2, sampler));
}
private:
@ -388,11 +298,11 @@ class StatefulUniformIntOp : public XlaOpKernel {
xla::XlaOp minval = ctx->Input(3);
xla::XlaOp maxval = ctx->Input(4);
auto sample_with_threefry = [minval, maxval, this](
xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
return StatefulRngUniform(key, counter, xla_shape, minval, maxval);
return StatefulRngUniform(key, state, xla_shape, minval, maxval);
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
@ -420,12 +330,11 @@ class StatefulUniformFullIntOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto sample_with_threefry = [this](
xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
auto sample_with_threefry = [this](xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
return StatefulRngUniformFullInt(key, counter, xla_shape);
return StatefulRngUniformFullInt(key, state, xla_shape);
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,

View File

@ -36,8 +36,8 @@ namespace tensorflow {
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
if (dtype == DT_BFLOAT16) {
xla::XlaBuilder* builder = input.builder();
auto output = xla::BitcastConvertType(input, xla::U32) &
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
xla::BF16);
} else {
@ -45,22 +45,36 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
}
}
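The masking above can be reproduced in NumPy by viewing the float32 bits as uint32 and keeping only the upper 16 bits, i.e. truncating the value to bfloat16 precision (a rough sketch):
import numpy as np

x = np.array([1.2345678], dtype=np.float32)
bits = x.view(np.uint32) & np.uint32(0xFFFF0000)  # keep sign, exponent and the top 7 mantissa bits
print(bits.view(np.float32))                      # [1.234375] -- truncated toward zero to bfloat16 precision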
xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
}
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
xla::XlaOp minval, xla::XlaOp maxval) {
xla::XlaBuilder* builder = seeds.builder();
// A wrapper of xla::StatelessRngUniform. Returns an op that produces random
// values with uniform distribution in the range [minval, maxval) for the given
// shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and
// S64 are implemented.
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused,
xla::XlaOp seed, xla::XlaOp minval,
xla::XlaOp maxval) {
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval);
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
xla::PrimitiveType type = shape.element_type();
switch (type) {
case xla::F32:
return xla::UniformF32Distribution(key, initial_state,
xla::ThreeFryBitGenerator, minval,
maxval, shape)
.value;
case xla::S32: // fall through
case xla::S64:
return UniformIntDistribution(key, initial_state,
xla::ThreeFryBitGenerator, minval, maxval,
shape)
.value;
break;
default:
return builder->ReportError(xla::Unimplemented(
"Types other than F32, S32 and S64 are not implemented by "
"StatelessRngUniform; got %s",
xla::primitive_util::LowercasePrimitiveTypeName(type)));
}
}
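The key construction at the top of this function packs the two 32-bit seeds into a single 64-bit ThreeFry key; with plain Python integers standing in for the XLA ops, it amounts to:
seed0, seed1 = 0x12345678, 0x9ABCDEF0  # example seed values
key = seed0 | (seed1 << 32)            # seed1 becomes the high 32 bits, seed0 the low 32 bits
print(hex(key))                        # 0x9abcdef012345678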
namespace {
@ -86,8 +100,8 @@ class StatelessRandomUniformOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
xla::XlaOp uniform = StatelessRandomUniformImpl(
xla_shape, dtype_, seed, xla::ConstantR0<float>(builder, 0.0),
xla::XlaOp uniform = StatelessRngUniform(
seed, xla_shape, xla::ConstantR0<float>(builder, 0.0),
xla::ConstantR0<float>(builder, 1.0));
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
ctx->SetOutput(0, uniform);
@ -136,8 +150,8 @@ class StatelessRandomUniformIntOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
xla::XlaOp uniform =
StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval);
xla::XlaOp uniform = StatelessRngUniform(seed, xla_shape, minval, maxval);
ctx->SetOutput(0, uniform);
}
@ -170,14 +184,20 @@ class StatelessRandomNormalOp : public XlaOpKernel {
errors::InvalidArgument("seed must have shape [2], not ",
seed_shape.DebugString()));
xla::XlaOp seed = ctx->Input(1);
xla::XlaBuilder* builder = ctx->builder();
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
xla::XlaOp uniform = StatelessRandomUniformImpl(
xla_shape, dtype_, seed,
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
xla::ConstantR0<float>(builder, 1.0));
xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform);
xla::XlaBuilder* builder = seed.builder();
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp normal =
xla::NormalF32Distribution(key, initial_state,
xla::ThreeFryBitGenerator, xla_shape)
.value;
normal = MaybeConvertF32ToBF16(normal, dtype_);
ctx->SetOutput(0, normal);
}
@ -215,8 +235,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
xla::XlaOp uniform = StatelessRandomUniformImpl(
xla_shape, dtype_, seed,
xla::XlaOp uniform = StatelessRngUniform(
seed, xla_shape,
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
xla::One(builder, xla_shape.element_type()));
xla::XlaOp truncated_normal = TruncatedNormal(uniform);

View File

@ -165,7 +165,7 @@ Status RewriteAndPruneGraph(
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
PruneForReverseReachability(graph, retval_nodes);
PruneForReverseReachability(graph, std::move(retval_nodes));
FixupSourceAndSinkEdges(graph);
VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.

View File

@ -287,6 +287,7 @@ tf_cc_test(
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
"@com_google_absl//absl/hash:hash_testing",
"@com_google_absl//absl/strings",
],
)

View File

@ -32,8 +32,12 @@ XlaOp RotateLeftU32(XlaOp v, int distance) {
ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
}
} // namespace
// The internal state of the ThreeFry implementation.
using ThreeFry2x32State = std::array<XlaOp, 2>;
// Implements the ThreeFry counter-based PRNG algorithm.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
XlaBuilder* builder = input[0].builder();
key[0] = BitcastConvertType(key[0], U32);
@ -104,56 +108,68 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
return x;
}
// Returns the inputs with unique counter values for ThreeFry2x32.
ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) {
ThreeFry2x32State inputs;
inputs[0] = Iota(builder, U32, size);
inputs[1] = inputs[0] + ConstantR0<uint32>(builder, size);
return inputs;
}
XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
XlaBuilder* builder = key[0].builder();
const int64 size = ShapeUtil::ElementsIn(shape);
const int64 half_size = CeilOfRatio<int64>(size, 2);
const bool size_is_odd = (half_size * 2 != size);
ThreeFry2x32State inputs = GetInputs(half_size, builder);
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
if (size_is_odd) {
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
}
auto result = ConcatInDim(builder, outputs, 0);
return Reshape(result, AsInt64Slice(shape.dimensions()));
}
// Converts a uint64 to two uint32s.
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
auto builder = u64.builder();
auto const32 = ConstantR0WithType(builder, U64, 32);
auto fst = ConvertElementType(u64, U32);
auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
XlaBuilder* builder = u64.builder();
XlaOp const32 = ConstantR0WithType(builder, U64, 32);
XlaOp fst = ConvertElementType(u64, U32);
XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
return {fst, snd};
}
// Converts two uint32s to a uint64.
XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
auto builder = u32s[0].builder();
XlaBuilder* builder = u32s[0].builder();
return ConvertElementType(u32s[0], U64) |
ShiftLeft(ConvertElementType(u32s[1], U64),
ConstantR0WithType(builder, U64, 32));
}
XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
XlaBuilder* builder = key[0].builder();
const int64 size = ShapeUtil::ElementsIn(shape);
ThreeFry2x32State inputs = GetInputs(size, builder);
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
// low 32 bit: outputs[0], high 32 bit: outputs[1]
auto result = Uint32sToUint64(outputs);
return Reshape(result, AsInt64Slice(shape.dimensions()));
// Given the initial state and the requested number of random numbers to be
// generated, returns the inputs for the random number generator and a new state.
std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
XlaOp initial_state, const int64 size) {
XlaBuilder* builder = initial_state.builder();
XlaOp input_u64 = Iota(builder, U64, size);
input_u64 = input_u64 + initial_state;
XlaOp new_state = initial_state + ConstantR0<uint64>(builder, size);
return std::make_pair(Uint64ToUint32s(input_u64), new_state);
}
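
In plain terms, the helper above derives per-element 64-bit counters by offsetting an iota with the incoming state and advances the state by the number of values produced. A rough NumPy sketch of that bookkeeping (not the XLA lowering):

import numpy as np

def get_threefry_inputs_and_updated_state(initial_state, size):
  # Counters: initial_state, initial_state + 1, ..., initial_state + size - 1.
  counters = np.arange(size, dtype=np.uint64) + np.uint64(initial_state)
  # Split each counter into (low 32 bits, high 32 bits) for the 2x32 cipher.
  low = (counters & np.uint64(0xFFFFFFFF)).astype(np.uint32)
  high = (counters >> np.uint64(32)).astype(np.uint32)
  new_state = np.uint64(initial_state) + np.uint64(size)
  return (low, high), new_state

inputs, state = get_threefry_inputs_and_updated_state(0, 5)
print(inputs[0], state)  # [0 1 2 3 4] 5
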
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
XlaBuilder* builder = bits.builder();
// Generates random 32-bit values with the given shape using the ThreeFry
// implementation. Returns the random bits and the new state.
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
XlaBuilder* builder = key.builder();
const int64 size = ShapeUtil::ElementsIn(shape);
const int64 half_size = CeilOfRatio<int64>(size, 2);
const bool size_is_odd = (half_size * 2 != size);
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
GetThreeFryInputsAndUpdatedState(initial_state, half_size);
ThreeFry2x32State inputs = inputs_state.first;
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
if (size_is_odd) {
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
}
XlaOp result = ConcatInDim(builder, outputs, 0);
return {Reshape(result, AsInt64Slice(shape.dimensions())),
inputs_state.second};
}
// Generates random 64-bit values with the given shape using the ThreeFry
// implementation. Returns the random bits and the new state.
RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
const int64 size = ShapeUtil::ElementsIn(shape);
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
GetThreeFryInputsAndUpdatedState(initial_state, size);
ThreeFry2x32State inputs = inputs_state.first;
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
XlaOp result = Uint32sToUint64(outputs);
return {Reshape(result, AsInt64Slice(shape.dimensions())),
inputs_state.second};
}
XlaOp ConvertRandomBitsToUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
XlaBuilder* builder = bits.builder();
// Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
// forces the random bits into the mantissa.
constexpr int kFloatBits = 32;
@ -161,50 +177,95 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
bits = ShiftRightLogical(
bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
auto floats = BitcastConvertType(bits, F32);
XlaOp values = BitcastConvertType(bits, F32);
// We have a floating point number in the range [1.0, 2.0).
// Subtract 1.0f to shift to the range [0.0, 1.0)
floats = floats - ConstantR0<float>(builder, 1.0f);
values = values - ConstantR0<float>(builder, 1.0f);
// Multiply and add to shift to the range [minval, maxval).
return floats * (maxval - minval) + minval;
return values * (maxval - minval) + minval;
}
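
The float conversion above is the usual mantissa trick: keep 23 random bits, OR in the exponent of 1.0f so the value lands in [1.0, 2.0), subtract 1.0, then rescale to [minval, maxval). An illustrative NumPy sketch of the same arithmetic:

import numpy as np

def bits_to_uniform_f32(bits, minval, maxval):
  bits = np.asarray(bits, dtype=np.uint32)
  # 0x3F800000 is the bit pattern of float32 1.0; shifting by 9 keeps 23 bits.
  mantissa = (bits >> np.uint32(9)) | np.uint32(0x3F800000)
  values = mantissa.view(np.float32) - np.float32(1.0)  # now in [0.0, 1.0)
  return values * (maxval - minval) + minval

raw = np.random.randint(0, 2**32, size=8, dtype=np.uint64).astype(np.uint32)
print(bits_to_uniform_f32(raw, np.float32(0.0), np.float32(1.0)))
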
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
PrimitiveType type, PrimitiveType unsigned_type) {
XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
PrimitiveType type,
PrimitiveType unsigned_type) {
XlaBuilder* builder = bits.builder();
auto range = BitcastConvertType(maxval, unsigned_type) -
BitcastConvertType(minval, unsigned_type);
auto dist = Rem(bits, range);
auto dist_div_2 =
XlaOp range = BitcastConvertType(maxval, unsigned_type) -
BitcastConvertType(minval, unsigned_type);
XlaOp dist = Rem(bits, range);
XlaOp dist_div_2 =
ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
return minval + BitcastConvertType(dist_div_2, type) +
BitcastConvertType(dist - dist_div_2, type);
}
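
The integer path above reduces the random bits modulo the requested range and adds the remainder to minval in two halves, mirroring the kernel. A rough NumPy sketch with made-up values (the result is only exactly uniform when the range divides 2^32, as the header comment notes):

import numpy as np

def bits_to_uniform_int(bits, minval, maxval):
  bits = np.asarray(bits, dtype=np.uint32)
  rng = np.uint32(np.int64(maxval) - np.int64(minval))  # range as an unsigned value
  dist = bits % rng
  half = dist >> np.uint32(1)
  # Mirror the kernel: add the remainder to minval in two signed halves.
  return np.int32(minval) + half.astype(np.int32) + (dist - half).astype(np.int32)

raw = np.random.randint(0, 2**32, size=8, dtype=np.uint64).astype(np.uint32)
print(bits_to_uniform_int(raw, np.int32(-5), np.int32(5)))  # values in [-5, 5)
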
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
XlaOp minval, XlaOp maxval) {
XlaBuilder* builder = seeds[0].builder();
XlaOp UniformToNormalUsingSqrtErfInv(XlaOp uniform) {
return ScalarLike(uniform, std::sqrt(2.0)) * ErfInv(uniform);
}
} // namespace
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
const Shape& shape) {
PrimitiveType type = shape.element_type();
switch (type) {
case F32: {
auto bits = StatelessRngUniformU32(seeds, shape);
return StatelessRngUniformF32(bits, minval, maxval);
}
case S32: {
auto bits = StatelessRngUniformU32(seeds, shape);
return StatelessRngUniformInt(bits, minval, maxval, type, U32);
}
case S64: {
auto bits = StatelessRngUniformU64(seeds, shape);
return StatelessRngUniformInt(bits, minval, maxval, type, U64);
}
case F32:
case U32:
case S32:
return ThreeFryRngBit32(key, initial_state, shape);
case U64:
case S64:
return ThreeFryRngBit64(key, initial_state, shape);
default:
return builder->ReportError(Unimplemented(
"Types other than F32, S32 and S64 are not implemented by "
"StatelessRngUniform."));
return {key.builder()->ReportError(Unimplemented(
"Types other than F32, U32, S32, U64 and S64 "
"are not implemented by ThreeFryBitGenerator; got %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
initial_state};
}
}
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator, XlaOp minval,
XlaOp maxval, const Shape& shape) {
DCHECK_EQ(shape.element_type(), F32);
RngOutput bits_state = bit_generator(key, initial_state, shape);
XlaOp bits = bits_state.value;
XlaOp new_state = bits_state.state;
return {ConvertRandomBitsToUniformF32(bits, minval, maxval), new_state};
}
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator, XlaOp minval,
XlaOp maxval, const Shape& shape) {
RngOutput bits_state = bit_generator(key, initial_state, shape);
XlaOp bits = bits_state.value;
XlaOp new_state = bits_state.state;
PrimitiveType type = shape.element_type();
PrimitiveType unsigned_type;
if (type == U32 || type == S32) {
unsigned_type = U32;
} else {
DCHECK(type == U64 || type == S64);
unsigned_type = U64;
}
return {
ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
new_state};
}
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator,
const Shape& shape) {
DCHECK_EQ(shape.element_type(), F32);
XlaBuilder* builder = key.builder();
RngOutput bits_state = UniformF32Distribution(
key, initial_state, bit_generator,
ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
ConstantR0<float>(builder, 1.0), shape);
XlaOp normal = UniformToNormalUsingSqrtErfInv(bits_state.value);
return {normal, bits_state.state};
}
} // namespace xla

View File

@ -23,37 +23,52 @@ limitations under the License.
namespace xla {
// Records the bits and state generated by a random number generator.
struct RngOutput {
XlaOp value;
XlaOp state;
};
// A BitGenerator returns random bits and updated random bit generator state.
//
// key: a value input to a random number generator that can affect the
// sequence of numbers it will generate. A random number generator constructs
// its seed using the key and the initial state. The tf2xla bridge passes the
// seed operand of a tensorflow random operation as a key to the random bit
// generator, for example.
// initial_state: the initial state of the current random number generation. It
// can be 0 for a stateless random operation, or the returned state from a
// previous execution for a stateful random operation.
// shape: the shape of the random bits.
using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state,
const xla::Shape& shape)>;
// Implements the ThreeFry counter-based PRNG algorithm.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
using ThreeFry2x32State = std::array<XlaOp, 2>;
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key);
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
const xla::Shape& shape);
// Returns a tensor containing 'shape' random values uniformly distributed in
// the range [minval, maxval). Requires 2 32-bit integer seeds.
// Currently only 'shape's of type F32, S32 and S64 are implemented.
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
XlaOp minval, XlaOp maxval);
// Uses the given bit generator to generate random bits and then converts the
// random bits to random numbers with a uniform distribution in the given range.
// Returns the random numbers and the state of the random number generator.
// This function is for shapes with a float element type.
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator, XlaOp minval,
XlaOp maxval, const xla::Shape& shape);
// Converts a 32-bit (signed or unsigned) integer random number `bits` into a
// float32 in the range [minval, maxval).
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
// Similar to UniformF32Distribution but for shapes with integer element types.
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator, XlaOp minval,
XlaOp maxval, const xla::Shape& shape);
// Converts an integer random number 'bits' of type 'type' to a random number
// in the range [minval, maxval), of the same type. 'unsigned_type' is the
// unsigned version of 'type' (could be the same) with the same bit width.
// The algorithm is the same one that TF uses right now, but it's
// uniform only when maxval - minval is a divisor of the range that bits is
// generated from.
// TODO(b/72573764): Generate real uniform integer distribution.
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
PrimitiveType type, PrimitiveType unsigned_type);
// The following 2 functions, for converting between one uint64 and two uint32s,
// use the contract "lower 32 bits for the first uint32, higher 32 bits for the
// second".
ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
// Uses the given bit generator to generate random bits and then converts the
// random bits to random numbers with a normal distribution.
// Returns the random numbers and the state of the random number generator.
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator,
const xla::Shape& shape);
} // namespace xla
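
The low/high split contract documented above is easy to check in isolation; a standalone NumPy sketch (not the XLA ops themselves):

import numpy as np

def uint64_to_uint32s(u64):
  u64 = np.uint64(u64)
  return np.uint32(u64 & np.uint64(0xFFFFFFFF)), np.uint32(u64 >> np.uint64(32))

def uint32s_to_uint64(lo, hi):
  return np.uint64(lo) | (np.uint64(hi) << np.uint64(32))

lo, hi = uint64_to_uint32s(0x1122334455667788)
assert uint32s_to_uint64(lo, hi) == np.uint64(0x1122334455667788)
print(hex(int(lo)), hex(int(hi)))  # 0x55667788 0x11223344
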

View File

@ -69,6 +69,11 @@ class Tile {
// combined with the next minor dimension before tiling is applied.
static constexpr int64 kCombineDimension = std::numeric_limits<int64>::min();
template <typename H>
friend H AbslHashValue(H h, const Tile& t) {
return H::combine(std::move(h), t.dimensions_);
}
private:
// The bounds of the tile.
std::vector<int64> dimensions_;
@ -212,6 +217,13 @@ class Layout {
element_size_in_bits_ = 0;
}
template <typename H>
friend H AbslHashValue(H h, const Layout& l) {
return H::combine(std::move(h), l.format_, l.minor_to_major_,
l.max_sparse_elements_, l.tiles_,
l.element_size_in_bits_);
}
private:
// The format of this layout.
Format format_ = INVALID_FORMAT;

View File

@ -109,6 +109,7 @@ tf_pybind_extension(
":xrt",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@pybind11",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
@ -81,9 +82,12 @@ StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
DeviceMemoryAllocator* allocator = client->backend().memory_allocator();
TransferManager* transfer_manager = client->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
Shape shape, transfer_manager->ChooseCompactLayoutForShape(tree.shape));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
transfer_manager->AllocateScopedShapedBuffer(
tree.shape, allocator, device_ordinal));
shape, allocator, device_ordinal));
TF_ASSIGN_OR_RETURN(auto stream,
client->mutable_backend()->BorrowStream(device_ordinal));
TF_RETURN_IF_ERROR(
@ -91,7 +95,7 @@ StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
auto it = tree.leaves.begin();
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(tree.shape)) {
ShapeUtil::GetLeafShapes(shape)) {
TF_RET_CHECK(it != tree.leaves.end());
ShapedBuffer leaf(
indexed_shape.shape,
@ -224,10 +228,7 @@ StatusOr<LocalShapedBuffer> LocalExecutableWrapper::Execute(
result_buffer_status = executable_->Run(argument_buffers, options);
if (!result_buffer_status.ok()) {
return InternalError(
"Failed running replica 0 (other replicas may have failed as well): "
"%s.",
result_buffer_status.status().ToString());
return result_buffer_status.status();
}
return LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
client_);
@ -298,10 +299,12 @@ LocalExecutableWrapper::ExecutePerReplica(
for (int replica = 0; replica < num_replicas(); ++replica) {
auto& statusor = results[replica];
if (!statusor.ok()) {
return InternalError(
"Failed running replica %d (other replicas may have failed as well): "
"%s.",
replica, statusor.status().ToString());
return AppendStatus(
statusor.status(),
absl::StrFormat(
"while running replica %d of a replicated computation (other "
"replicas may have failed as well).",
replica));
}
wrapped_results[replica] =
LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
@ -346,23 +349,53 @@ StatusOr<std::string> GetComputationHloDotGraph(
/*static*/ StatusOr<std::unique_ptr<LocalExecutableWrapper>>
LocalExecutableWrapper::Compile(const XlaComputation& computation,
const std::vector<Shape>& argument_shapes,
std::vector<Shape> argument_layouts,
const ExecutableBuildOptions* build_options,
LocalClient* client) {
tensorflow::profiler::TraceMe("LocalExecutable::Compile");
std::vector<const Shape*> argument_shape_pointers;
argument_shape_pointers.reserve(argument_shapes.size());
for (auto& argument_shape : argument_shapes) {
argument_shape_pointers.push_back(&argument_shape);
std::vector<const Shape*> argument_layout_pointers;
argument_layout_pointers.reserve(argument_layouts.size());
// Assign a default layout to any array subshapes that are missing layouts.
auto assign_layouts = [client](Shape* shape) {
return ShapeUtil::ForEachMutableSubshapeWithStatus(
shape, [&](Shape* subshape, const ShapeIndex&) {
if (subshape->IsArray() && !subshape->has_layout()) {
LayoutUtil::SetToDefaultLayout(subshape);
TF_ASSIGN_OR_RETURN(*subshape,
client->backend()
.transfer_manager()
->ChooseCompactLayoutForShape(*subshape));
}
return Status::OK();
});
};
for (Shape& layout : argument_layouts) {
argument_layout_pointers.push_back(&layout);
assign_layouts(&layout);
}
ExecutableBuildOptions options;
if (build_options != nullptr) {
options = *build_options;
}
Shape result_layout;
if (options.result_layout()) {
result_layout = *options.result_layout();
} else {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation.GetProgramShape());
result_layout = program_shape.result();
LayoutUtil::ClearLayout(&result_layout);
}
assign_layouts(&result_layout);
options.set_result_layout(result_layout);
TF_ASSIGN_OR_RETURN(
auto local_executable,
client->Compile(computation, argument_shape_pointers, options));
client->Compile(computation, argument_layout_pointers, options));
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->backend().computation_placer()->AssignDevices(
options.num_replicas(), /*computation_count=*/1));

View File

@ -82,8 +82,7 @@ class LocalExecutableWrapper {
public:
// Compiles a computation to an executable.
static StatusOr<std::unique_ptr<LocalExecutableWrapper>> Compile(
const XlaComputation& computation,
const std::vector<Shape>& argument_shapes,
const XlaComputation& computation, std::vector<Shape> argument_layouts,
const ExecutableBuildOptions* build_options, LocalClient* client);
LocalExecutableWrapper(std::unique_ptr<LocalExecutable> executable,

View File

@ -75,7 +75,9 @@ PYBIND11_MODULE(xla_extension, m) {
if (layout) {
return ShapeUtil::MakeShapeWithLayout(type, dims, *layout);
} else {
return ShapeUtil::MakeShape(type, dims);
Shape shape = ShapeUtil::MakeShape(type, dims);
shape.clear_layout();
return shape;
}
},
"Makes an array shape.", py::arg("type"), py::arg("dims"),
@ -87,7 +89,9 @@ PYBIND11_MODULE(xla_extension, m) {
.def("tuple_shapes",
static_cast<const std::vector<Shape>& (Shape::*)() const>(
&Shape::tuple_shapes))
.def("__repr__", [](const Shape& shape) { return shape.ToString(); });
.def("__repr__", [](const Shape& shape) {
return shape.ToString(/*print_layouts=*/true);
});
py::class_<ProgramShape>(m, "ProgramShape")
.def(py::init(

View File

@ -73,8 +73,7 @@ class Backend(object):
"""Destructures a tuple buffer into a sequence of buffers."""
@abc.abstractmethod
def compile(self, computation, argument_shapes, result_shape,
compile_options):
def compile(self, computation, compile_options):
"""Compiles a computation. Returns an executable."""
@abc.abstractmethod
@ -115,15 +114,19 @@ class LocalBackend(Backend):
def destructure_tuple(self, c_buffer):
return c_buffer.DestructureTuple()
def compile(self, c_computation, argument_shapes, result_shape,
compile_options):
def compile(self, c_computation, compile_options):
options = _xla.ExecutableBuildOptions()
options.num_replicas = compile_options.num_replicas
if compile_options.argument_layouts:
argument_layouts = [
s.as_xla_shape() for s in compile_options.argument_layouts
]
else:
argument_layouts = c_computation.GetProgramShape().Parameters()
if compile_options.result_layout:
options.result_layout = compile_options.result_layout.as_xla_shape()
argument_shapes = [s.as_xla_shape() for s in argument_shapes]
return _xla.LocalExecutable.Compile(c_computation, argument_shapes, options,
self.client)
return _xla.LocalExecutable.Compile(c_computation, argument_layouts,
options, self.client)
def delete_executable(self, executable):
executable.Delete()
@ -234,42 +237,6 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
source_line=lineno)
class PaddingType(enum.Enum):
VALID = 1
SAME = 2
def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
window_strides):
"""Maps PaddingType or string to pad values (list of pairs of ints)."""
if not isinstance(padding_type, (str, PaddingType)):
msg = 'padding_type must be str or PaddingType, got {}.'
raise TypeError(msg.format(type(padding_type)))
if isinstance(padding_type, str):
if padding_type.upper() == 'VALID':
padding_type = PaddingType.VALID
elif padding_type.upper() == 'SAME':
padding_type = PaddingType.SAME
else:
msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
raise ValueError(msg.format(padding_type))
if padding_type == PaddingType.VALID:
return [(0, 0)] * len(window_strides)
elif padding_type == PaddingType.SAME:
out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
pad_sizes = [
max((out_size - 1) * stride + filter_size - in_size, 0)
for out_size, stride, filter_size, in_size in zip(
out_shape, window_strides, rhs_dims, lhs_dims)
]
return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
else:
msg = 'Unexpected PaddingType value: {}'
raise ValueError(msg.format(padding_type))
PrimitiveType = _xla.PrimitiveType
XLA_ELEMENT_TYPE_TO_DTYPE = {
@ -303,7 +270,7 @@ def dtype_to_etype(dtype):
return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
class LocalBuffer(object):
class Buffer(object):
"""Represents a handle to data owned by XLA.
The referent is ready for use in executing a local, compiled
@ -322,7 +289,7 @@ class LocalBuffer(object):
backend = backend or get_local_backend()
pyval = require_numpy_array_layout(pyval)
cbuf = backend.buffer_from_pyval(pyval, device)
return LocalBuffer(cbuf, backend, device)
return Buffer(cbuf, backend, device)
def to_py(self):
return self.c_buffer.ToPython()
@ -344,7 +311,7 @@ class LocalBuffer(object):
result = self._backend.destructure_tuple(self.c_buffer)
self.delete()
return tuple(
LocalBuffer(sub_buffer, device=self._device, backend=self._backend)
Buffer(sub_buffer, device=self._device, backend=self._backend)
for sub_buffer in result)
def is_deleted(self):
@ -354,6 +321,11 @@ class LocalBuffer(object):
self.delete()
# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
# compatibility with Jaxlib versions older than 0.1.13.
LocalBuffer = Buffer
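
As a usage note, code that constructed LocalBuffer objects keeps working through the alias, while new code can use Buffer directly. A minimal sketch, assuming a local XLA backend is available (the import path shown is the in-tree one and may differ in other builds):

import numpy as np
from tensorflow.compiler.xla.python import xla_client

# Device-resident copy of a NumPy array; Buffer is the new name, LocalBuffer the alias.
buf = xla_client.Buffer.from_pyval(np.array([[1.0, 2.0]], dtype=np.float32))
print(buf.shape().dimensions())  # (1, 2)
print(buf.to_py())
buf.delete()
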
class Format(enum.IntEnum):
"""Python copy of the Format protocol buffer enum."""
INVALID_FORMAT = 0
@ -396,6 +368,7 @@ class Shape(object):
@staticmethod
def from_pyval(pyval):
"""Returns a Shape that describes a tuple-tree of Numpy arrays."""
def convert(pyval):
if isinstance(pyval, tuple):
return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
@ -561,24 +534,6 @@ def require_numpy_array_layout(value):
return np.require(value, requirements=['C', 'A'])
class CompileOptions(object):
"""Python object for XLA compile options.
These options can be passed to the 'compile' step when using a local XLA
client.
"""
def __init__(self):
self.xla_dump_to = None
self.dump_hlo_pass_re = None
self.dump_hlo_module_re = None
self.dump_hlo_as_text = None
self.dump_hlo_as_proto = None
self.hlo_profile = None
self.num_replicas = 1
self.result_layout = None
def transfer_to_infeed(value, device_ordinal=0):
"""Transfers the given value into the XLA infeed queue.
@ -611,8 +566,28 @@ def transfer_from_outfeed(shape, device_ordinal=0):
"""
# TODO(phawkins): support non-default backends.
backend = get_local_backend()
return backend.client.TransferFromOutfeed(shape.as_xla_shape(),
device_ordinal)
return backend.client.TransferFromOutfeed(
shape.with_major_to_minor_layout_if_absent().as_xla_shape(),
device_ordinal)
class CompileOptions(object):
"""Python object for XLA compile options.
These options can be passed to the 'compile' step when using a local XLA
client.
"""
def __init__(self):
self.xla_dump_to = None
self.dump_hlo_pass_re = None
self.dump_hlo_module_re = None
self.dump_hlo_as_text = None
self.dump_hlo_as_proto = None
self.hlo_profile = None
self.num_replicas = 1
self.argument_layouts = None
self.result_layout = None
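
To illustrate the new compile flow, argument and result layouts now travel in CompileOptions rather than as positional arguments to compile(). A hedged sketch of a caller, assuming a local backend is available (shapes and values here are made up):

import numpy as np
from tensorflow.compiler.xla.python import xla_client

c = xla_client.ComputationBuilder("add_scalar")
c.Add(c.ParameterFromNumpy(np.array(0.0, dtype=np.float32)), c.ConstantF32Scalar(3.14))

options = xla_client.CompileOptions()
options.num_replicas = 1
# Optional explicit argument layouts; when omitted, defaults are chosen.
options.argument_layouts = [xla_client.Shape.from_pyval(np.array(0.0, dtype=np.float32))]

compiled = c.Build().Compile(compile_options=options)
print(compiled.ExecuteWithPythonValues([np.array(1.11, dtype=np.float32)]))  # ~4.25
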
class Computation(object):
@ -656,55 +631,28 @@ class Computation(object):
"""
return self.computation.GetHloDotGraph()
def Compile(self,
argument_shapes=(),
compile_options=None,
layout_fn=None,
backend=None):
def Compile(self, argument_shapes=None, compile_options=None, backend=None):
"""Compiles a computation.
Computations are the result of a "ComputationBuild'ing" process.
Arguments:
argument_shapes: parameter shapes -- they are first laid out by layout_fn
if layout_fn is provided. Otherwise, the default layout for those shapes
will be used.
argument_shapes: Deprecated. Use compile_options.argument_layouts instead.
compile_options: options to use for compilation, including an optional
laid-out result shape for the computation.
layout_fn: lambda that is used to lay out the argument/result shapes.
backend: a `Backend` for which an executable should be generated.
Returns:
An Executable instance.
"""
backend = backend or self._backend or get_local_backend()
result_shape = _wrap_shape(self.computation.GetProgramShape().Result())
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
]
result_shape = result_shape.map_leaves(layout_fn)
argument_shapes = list(argument_shapes)
compile_options = compile_options or CompileOptions()
compile_options.result_layout = result_shape
c = backend.compile(self.computation, argument_shapes, result_shape,
compile_options)
if argument_shapes:
compile_options.argument_layouts = argument_shapes
c = backend.compile(self.computation, compile_options)
return Executable(c, backend=backend)
def CompileWithExampleArguments(self,
arguments=(),
compile_options=None,
layout_fn=None,
backend=None):
return self.Compile(
argument_shapes=[Shape.from_pyval(arg) for arg in arguments],
compile_options=compile_options,
layout_fn=layout_fn,
backend=backend)
def GetProgramShape(self):
return _wrap_program_shape(self._c_computation.GetProgramShape())
@ -725,25 +673,25 @@ class Executable(object):
return self._device_ordinals
def Execute(self, arguments=(), check_for_deleted_args=True):
"""Execute on one replica with LocalBuffer arguments and return value."""
"""Execute on one replica with Buffer arguments and return value."""
if check_for_deleted_args and any(arg.is_deleted() for arg in arguments):
raise ValueError('Executing with deleted local buffer argument')
raw_args = [arg.c_buffer for arg in arguments]
output_buffer = self._backend.execute(self._c_executable, raw_args)
return LocalBuffer(
return Buffer(
output_buffer, backend=self._backend, device=self._device_ordinals[0])
def ExecutePerReplica(self, arguments=None):
"""Execute on many replicas with LocalBuffer arguments and return value.
"""Execute on many replicas with Buffer arguments and return value.
Args:
arguments: A sequence of sequences of LocalBuffers. The i'th inner
sequence comprises the arguments for execution on the i'th replica.
arguments: A sequence of sequences of Buffers. The i'th inner sequence
comprises the arguments for execution on the i'th replica.
Returns:
A list of the computation's outputs for each replica, as a LocalBuffer. If
A list of the computation's outputs for each replica, as a Buffer. If
a shallow sequence of arguments was passed in for `arguments`, then the
sole, zero'th replica's output is returned instead, as a LocalBuffer.
sole, zero'th replica's output is returned instead, as a Buffer.
"""
if arguments is None:
arguments = ((),) * len(self._device_ordinals)
@ -770,9 +718,9 @@ class Executable(object):
output_buffers = self._backend.execute_replicated(self._c_executable,
stripped_args)
# Wrap output handles in LocalBuffer instances
# Wrap output handles in Buffer instances
return tuple(
LocalBuffer(
Buffer(
output_buffer,
backend=self._backend,
device=self._device_ordinals[replica])
@ -782,7 +730,7 @@ class Executable(object):
"""Execute on one replica with Python values as arguments and output."""
def put(arg):
return LocalBuffer.from_pyval(
return Buffer.from_pyval(
arg, device=self._device_ordinals[0], backend=self._backend)
arguments = [put(arg) for arg in arguments]
@ -792,7 +740,7 @@ class Executable(object):
"""Execute on many replicas with Python values as arguments and output."""
def put(arg, device):
return LocalBuffer.from_pyval(arg, device, backend=self._backend)
return Buffer.from_pyval(arg, device, backend=self._backend)
# pylint: disable=g-complex-comprehension
arguments = [[
@ -804,6 +752,42 @@ class Executable(object):
self._backend.delete_executable(self._c_executable)
class PaddingType(enum.Enum):
VALID = 1
SAME = 2
def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
window_strides):
"""Maps PaddingType or string to pad values (list of pairs of ints)."""
if not isinstance(padding_type, (str, PaddingType)):
msg = 'padding_type must be str or PaddingType, got {}.'
raise TypeError(msg.format(type(padding_type)))
if isinstance(padding_type, str):
if padding_type.upper() == 'VALID':
padding_type = PaddingType.VALID
elif padding_type.upper() == 'SAME':
padding_type = PaddingType.SAME
else:
msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
raise ValueError(msg.format(padding_type))
if padding_type == PaddingType.VALID:
return [(0, 0)] * len(window_strides)
elif padding_type == PaddingType.SAME:
out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
pad_sizes = [
max((out_size - 1) * stride + filter_size - in_size, 0)
for out_size, stride, filter_size, in_size in zip(
out_shape, window_strides, rhs_dims, lhs_dims)
]
return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
else:
msg = 'Unexpected PaddingType value: {}'
raise ValueError(msg.format(padding_type))
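
For a concrete feel of the SAME padding arithmetic above: with an input extent of 3, a window of 2 and stride 1, the output extent is ceil(3/1) = 3, so the total pad is (3-1)*1 + 2 - 3 = 1, split as (0, 1). A standalone NumPy check that mirrors the helper (not imported from it):

import numpy as np

def same_pads(lhs_dims, rhs_dims, window_strides):
  out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
  pad_sizes = [
      max((out_size - 1) * stride + filter_size - in_size, 0)
      for out_size, stride, filter_size, in_size in zip(
          out_shape, window_strides, rhs_dims, lhs_dims)
  ]
  return [(pad // 2, pad - pad // 2) for pad in pad_sizes]

print(same_pads(lhs_dims=(3, 4), rhs_dims=(2, 2), window_strides=(1, 1)))
# [(0, 1), (0, 1)]
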
class ComputationBuilder(object):
"""XLA computation builder.
@ -858,7 +842,9 @@ class ComputationBuilder(object):
Returns:
An XlaOp.
"""
return ops.Infeed(self._builder, shape.as_xla_shape())
return ops.Infeed(
self._builder,
shape.with_major_to_minor_layout_if_absent().as_xla_shape())
def Outfeed(self, operand):
"""Enqueues an outfeed op onto the computation.
@ -955,8 +941,10 @@ class ComputationBuilder(object):
if parameter_num is None:
parameter_num = next(self._parameter_numbering)
return ops.Parameter(self._builder, parameter_num, shape.as_xla_shape(),
name.encode('utf8'))
return ops.Parameter(
self._builder, parameter_num,
shape.with_major_to_minor_layout_if_absent().as_xla_shape(),
name.encode('utf8'))
def ParameterFromNumpy(self, value, name=None, parameter_num=None):
"""Enqueues a Parameter op onto the computation.
@ -1143,9 +1131,11 @@ class ComputationBuilder(object):
pads = _convert_padding_type_to_pad_values(
padding,
self.GetShape(operand).dimensions(), window_dimensions, window_strides)
return ops.SelectAndScatterWithGeneralPadding(
operand, select.computation, window_dimensions, window_strides, pads,
source, init_value, scatter.computation)
return ops.SelectAndScatterWithGeneralPadding(operand, select.computation,
window_dimensions,
window_strides, pads, source,
init_value,
scatter.computation)
def Slice(self, operand, start_indices, limit_indices, strides=None):
"""Enqueues a slice operation onto the computation.
@ -1306,13 +1296,15 @@ class ComputationBuilder(object):
pads = _convert_padding_type_to_pad_values(
padding,
self.GetShape(operand).dimensions(), window_dimensions, window_strides)
return ops.ReduceWindowWithGeneralPadding(
operand, init_value, computation_to_apply.computation,
window_dimensions, window_strides, (), (), pads)
return ops.ReduceWindowWithGeneralPadding(operand, init_value,
computation_to_apply.computation,
window_dimensions, window_strides,
(), (), pads)
def ReduceWindowWithGeneralPadding(
self, operand, init_value, computation_to_apply, window_dimensions,
window_strides, base_dilations, window_dilations, padding):
def ReduceWindowWithGeneralPadding(self, operand, init_value,
computation_to_apply, window_dimensions,
window_strides, base_dilations,
window_dilations, padding):
"""Enqueues a windowed reduction operation onto the computation.
Args:
@ -1328,10 +1320,11 @@ class ComputationBuilder(object):
Returns:
An XlaOp representing the added ReduceWindow op.
"""
return ops.ReduceWindowWithGeneralPadding(
operand, init_value, computation_to_apply.computation,
window_dimensions, window_strides, base_dilations, window_dilations,
padding)
return ops.ReduceWindowWithGeneralPadding(operand, init_value,
computation_to_apply.computation,
window_dimensions, window_strides,
base_dilations, window_dilations,
padding)
def RngNormal(self, mu, sigma, dims):
"""Enqueues an RngNormal operation onto the computation.
@ -1709,6 +1702,7 @@ def _forward_methods_to_local_builder():
forward.__name__ = method_name
setattr(ComputationBuilder, method_name, forward)
_forward_methods_to_local_builder()

View File

@ -38,7 +38,7 @@ class ComputationTest(unittest.TestCase):
return xla_client.ComputationBuilder(name)
def _Execute(self, c, arguments):
compiled_c = c.Build().CompileWithExampleArguments(arguments)
compiled_c = c.Build().Compile()
return compiled_c.ExecuteWithPythonValues(arguments)
def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
@ -53,11 +53,15 @@ class ComputationTest(unittest.TestCase):
def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)
def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7,
def _ExecuteAndCompareClose(self,
c,
arguments=(),
expected=None,
rtol=1e-7,
atol=0):
self._ExecuteAndAssertWith(
functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
c, arguments, expected)
functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), c,
arguments, expected)
def NumpyArrayF32(*args, **kwargs):
@ -212,20 +216,19 @@ class ComputationsWithConstantsTest(ComputationTest):
def testShiftLeft(self):
c = self._NewComputation()
c.ShiftLeft(c.Constant(NumpyArrayS32([3])),
c.Constant(NumpyArrayS32([2])))
c.ShiftLeft(c.Constant(NumpyArrayS32([3])), c.Constant(NumpyArrayS32([2])))
self._ExecuteAndCompareClose(c, expected=[12])
def testShiftRightArithmetic(self):
c = self._NewComputation()
c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])),
c.Constant(NumpyArrayS32([1])))
c.ShiftRightArithmetic(
c.Constant(NumpyArrayS32([-2])), c.Constant(NumpyArrayS32([1])))
self._ExecuteAndCompareClose(c, expected=[-1])
def testShiftRightLogical(self):
c = self._NewComputation()
c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])),
c.Constant(NumpyArrayS32([1])))
c.ShiftRightLogical(
c.Constant(NumpyArrayS32([-1])), c.Constant(NumpyArrayS32([1])))
self._ExecuteAndCompareClose(c, expected=[2**31 - 1])
def testSum2DF64(self):
@ -396,7 +399,7 @@ class LocalBufferTest(ComputationTest):
"""Tests focusing on execution with LocalBuffers."""
def _Execute(self, c, arguments):
compiled_c = c.Build().CompileWithExampleArguments(arguments)
compiled_c = c.Build().Compile()
arg_buffers = [xla_client.LocalBuffer.from_pyval(arg) for arg in arguments]
result_buffer = compiled_c.Execute(arg_buffers)
return result_buffer.to_py()
@ -410,24 +413,22 @@ class LocalBufferTest(ComputationTest):
c = self._NewComputation()
c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14))
self._ExecuteAndCompareClose(
c,
arguments=[NumpyArrayF32(1.11)],
expected=4.25)
c, arguments=[NumpyArrayF32(1.11)], expected=4.25)
def testTwoParameterSum(self):
c = self._NewComputation()
c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)),
c.ParameterFromNumpy(NumpyArrayF32(0.)))
c.Add(
c.ParameterFromNumpy(NumpyArrayF32(0.)),
c.ParameterFromNumpy(NumpyArrayF32(0.)))
self._ExecuteAndCompareClose(
c,
arguments=[NumpyArrayF32(1.11), NumpyArrayF32(3.14)],
expected=4.25)
c, arguments=[NumpyArrayF32(1.11),
NumpyArrayF32(3.14)], expected=4.25)
def testCannotCallWithDeletedBuffers(self):
c = self._NewComputation()
c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14))
arg = NumpyArrayF32(1.11)
compiled_c = c.Build().CompileWithExampleArguments([arg])
compiled_c = c.Build().Compile()
arg_buffer = xla_client.LocalBuffer.from_pyval(arg)
arg_buffer.delete()
with self.assertRaises(ValueError):
@ -452,8 +453,8 @@ class LocalBufferTest(ComputationTest):
np.testing.assert_equal(want, got)
def testDestructureTupleTwoArrayElementDifferentType(self):
t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
np.array([2, 3, 4, 5], dtype=np.int32))
t = (np.array([1.0, 2.0, 3.0, 4.0],
dtype=np.float32), np.array([2, 3, 4, 5], dtype=np.int32))
local_buffer = xla_client.LocalBuffer.from_pyval(t)
pieces = local_buffer.destructure()
self.assertTrue(local_buffer.is_deleted())
@ -486,7 +487,7 @@ class LocalBufferTest(ComputationTest):
pyval = np.array([[1., 2.]], np.float32)
local_buffer = xla_client.LocalBuffer.from_pyval(pyval)
xla_shape = local_buffer.shape()
self.assertEqual(xla_shape.dimensions(), (1, 2,))
self.assertEqual(xla_shape.dimensions(), (1, 2))
self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
@ -500,18 +501,20 @@ class SingleOpTest(ComputationTest):
def testConcatenateF32(self):
c = self._NewComputation()
c.Concatenate(
(c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))),
dimension=0)
args = (
c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
c.Constant(NumpyArrayF32([4.0, 5.0, 6.0])),
)
c.Concatenate(args, dimension=0)
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
def testConcatenateF64(self):
c = self._NewComputation()
c.Concatenate(
(c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))),
dimension=0)
args = (
c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
c.Constant(NumpyArrayF64([4.0, 5.0, 6.0])),
)
c.Concatenate(args, dimension=0)
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
def testConvertElementType(self):
@ -665,11 +668,13 @@ class SingleOpTest(ComputationTest):
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 2, 3, 4)
rhs = a(1, 2, 1, 2) * 10
c.Conv(c.Constant(lhs), c.Constant(rhs),
[1, 1], xla_client.PaddingType.SAME)
result = np.array([[[[640., 700., 760., 300.],
[880., 940., 1000., 380.],
[1120., 1180., 1240., 460.]]]])
c.Conv(
c.Constant(lhs), c.Constant(rhs), [1, 1], xla_client.PaddingType.SAME)
result = np.array([[[
[640., 700., 760., 300.],
[880., 940., 1000., 380.],
[1120., 1180., 1240., 460.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvF32Valid(self):
@ -677,10 +682,12 @@ class SingleOpTest(ComputationTest):
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 2, 3, 4)
rhs = a(1, 2, 1, 2) * 10
c.Conv(c.Constant(lhs), c.Constant(rhs),
[2, 1], xla_client.PaddingType.VALID)
result = np.array([[[[640., 700., 760.],
[1120., 1180., 1240.]]]])
c.Conv(
c.Constant(lhs), c.Constant(rhs), [2, 1], xla_client.PaddingType.VALID)
result = np.array([[[
[640., 700., 760.],
[1120., 1180., 1240.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvWithGeneralPaddingF32(self):
@ -692,12 +699,15 @@ class SingleOpTest(ComputationTest):
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs),
strides, pads, lhs_dilation, rhs_dilation)
result = np.array([[[[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.]]]])
c.ConvWithGeneralPadding(
c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
rhs_dilation)
result = np.array([[[
[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvGeneralDilatedF32(self):
@ -710,13 +720,15 @@ class SingleOpTest(ComputationTest):
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
dimension_numbers = ("NCHW", "OIHW", "NCHW")
c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
strides, pads, lhs_dilation, rhs_dilation,
dimension_numbers)
result = np.array([[[[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.]]]])
c.ConvGeneralDilated(
c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
rhs_dilation, dimension_numbers)
result = np.array([[[
[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvGeneralDilatedPermutedF32(self):
@ -730,13 +742,10 @@ class SingleOpTest(ComputationTest):
rhs_dilation = (1, 1)
dimension_numbers = ("NHWC", "OIHW", "CWNH")
c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))),
c.Constant(rhs),
strides, pads, lhs_dilation, rhs_dilation,
dimension_numbers)
result = np.array([[[[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
c.ConvGeneralDilated(
c.Constant(np.transpose(lhs, (0, 2, 3, 1))), c.Constant(rhs), strides,
pads, lhs_dilation, rhs_dilation, dimension_numbers)
result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
@ -751,17 +760,20 @@ class SingleOpTest(ComputationTest):
rhs_dilation = (1, 1)
dimension_numbers = ("NCHW", "OIHW", "NCHW")
feature_group_count = 2
c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
strides, pads, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count)
result = np.array([[[[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.]],
[[0., 0., 0.],
[330., 380., 160.],
[0., 0., 0.],
[480., 530., 220.]]]])
c.ConvGeneralDilated(
c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count)
result = np.array([[[
[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.],
], [
[0., 0., 0.],
[330., 380., 160.],
[0., 0., 0.],
[480., 530., 220.],
]]])
self._ExecuteAndCompareClose(c, expected=result)
def testBooleanNot(self):
@ -952,14 +964,11 @@ class SingleOpTest(ComputationTest):
c = self._NewComputation()
c.Pad(
c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
c.Constant(NumpyArrayF32(0.0)),
[(1, 2, 1), (0, 1, 0)])
self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0],
[1.0, 2.0, 0.0],
[0.0, 0.0, 0.0],
[3.0, 4.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])
c.Constant(NumpyArrayF32(0.0)), [(1, 2, 1), (0, 1, 0)])
self._ExecuteAndCompareClose(
c,
expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
def testPadWithPaddingConfig(self):
c = self._NewComputation()
@ -972,14 +981,11 @@ class SingleOpTest(ComputationTest):
padding_config.dimensions.append(dimension)
c.Pad(
c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
c.Constant(NumpyArrayF32(0.0)),
padding_config)
self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0],
[1.0, 2.0, 0.0],
[0.0, 0.0, 0.0],
[3.0, 4.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])
c.Constant(NumpyArrayF32(0.0)), padding_config)
self._ExecuteAndCompareClose(
c,
expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
def testReshape(self):
c = self._NewComputation()
@ -1102,8 +1108,10 @@ class SingleOpTest(ComputationTest):
def testRngNormal(self):
shape = (2, 3)
c = self._NewComputation()
c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)),
dims=shape)
c.RngNormal(
c.Constant(NumpyArrayF32(0.)),
c.Constant(NumpyArrayF32(1.)),
dims=shape)
result = c.Build().Compile().ExecuteWithPythonValues()
# since the result is random, we just check shape and uniqueness
self.assertEqual(result.shape, shape)
@ -1113,8 +1121,10 @@ class SingleOpTest(ComputationTest):
lo, hi = 2., 4.
shape = (2, 3)
c = self._NewComputation()
c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)),
dims=shape)
c.RngUniform(
c.Constant(NumpyArrayF32(lo)),
c.Constant(NumpyArrayF32(hi)),
dims=shape)
result = c.Build().Compile().ExecuteWithPythonValues()
# since the result is random, we just check shape, uniqueness, and range
self.assertEqual(result.shape, shape)
@ -1126,8 +1136,10 @@ class SingleOpTest(ComputationTest):
lo, hi = 2, 4
shape = (2, 3)
c = self._NewComputation()
c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)),
dims=shape)
c.RngUniform(
c.Constant(NumpyArrayS32(lo)),
c.Constant(NumpyArrayS32(hi)),
dims=shape)
result = c.Build().Compile().ExecuteWithPythonValues()
# since the result is random, we just check shape, integrality, and range
self.assertEqual(result.shape, shape)
@ -1180,13 +1192,21 @@ class SingleOpTest(ComputationTest):
dtype=np.float32)
c = self._NewComputation()
c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False,
lower=True, transpose_a=True)
self._ExecuteAndCompareClose(c, expected=np.array([
[0.5, 0.08333334, 0.04629629, 0.03367003],
[2.5, -0.25, -0.1388889, -0.1010101],
[4.5, -0.58333331, -0.32407406, -0.23569024],
], dtype=np.float32), rtol=1e-4)
c.TriangularSolve(
c.Constant(a_vals),
c.Constant(b_vals),
left_side=False,
lower=True,
transpose_a=True)
self._ExecuteAndCompareClose(
c,
expected=np.array([
[0.5, 0.08333334, 0.04629629, 0.03367003],
[2.5, -0.25, -0.1388889, -0.1010101],
[4.5, -0.58333331, -0.32407406, -0.23569024],
],
dtype=np.float32),
rtol=1e-4)
def testIsConstant(self):
c = self._NewComputation()
@ -1261,8 +1281,9 @@ class EmbeddedComputationsTest(ComputationTest):
def _CreateMulF32ByParamComputation(self):
"""Computation (f32) -> f32 that multiplies one parameter by the other."""
c = self._NewComputation("mul_f32_by_param")
c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
c.Mul(
c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
return c.Build()
def _CreateMulF64By2Computation(self):
@ -1326,15 +1347,17 @@ class EmbeddedComputationsTest(ComputationTest):
def _CreateBinaryGeF32Computation(self):
"""Computation (f32, f32) -> bool that tests first_param >= second_param."""
c = self._NewComputation("param0_lt_param1")
c.Ge(c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
c.Ge(
c.ParameterFromNumpy(NumpyArrayF32(0)),
c.ParameterFromNumpy(NumpyArrayF32(0)))
return c.Build()
def _CreateBinaryGeF64Computation(self):
"""Computation (f64, f64) -> bool that tests first_param >= second_param."""
c = self._NewComputation("param0_lt_param1")
c.Ge(c.ParameterFromNumpy(NumpyArrayF64(0)),
c.ParameterFromNumpy(NumpyArrayF64(0)))
c.Ge(
c.ParameterFromNumpy(NumpyArrayF64(0)),
c.ParameterFromNumpy(NumpyArrayF64(0)))
return c.Build()
def _MakeSample3DArrayF32(self):
@ -1415,26 +1438,28 @@ class EmbeddedComputationsTest(ComputationTest):
def testSelectAndScatterF32(self):
c = self._NewComputation()
c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
select=self._CreateBinaryGeF32Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID,
source=c.Constant(NumpyArrayF32([[0.1, 0.2]])),
init_value=c.Constant(NumpyArrayF32(1)),
scatter=self._CreateBinaryAddF32Computation())
c.SelectAndScatter(
c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
select=self._CreateBinaryGeF32Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID,
source=c.Constant(NumpyArrayF32([[0.1, 0.2]])),
init_value=c.Constant(NumpyArrayF32(1)),
scatter=self._CreateBinaryAddF32Computation())
self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
def testSelectAndScatterF64(self):
c = self._NewComputation()
c.SelectAndScatter(c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])),
select=self._CreateBinaryGeF64Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID,
source=c.Constant(NumpyArrayF64([[0.1, 0.2]])),
init_value=c.Constant(NumpyArrayF64(1)),
scatter=self._CreateBinaryAddF64Computation())
c.SelectAndScatter(
c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])),
select=self._CreateBinaryGeF64Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID,
source=c.Constant(NumpyArrayF64([[0.1, 0.2]])),
init_value=c.Constant(NumpyArrayF64(1)),
scatter=self._CreateBinaryAddF64Computation())
self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
def testReduce1DtoScalarF32(self):
@ -1537,61 +1562,73 @@ class EmbeddedComputationsTest(ComputationTest):
def testReduceWindowValidUnitStridesF32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
window_dimensions=(2, 1), window_strides=(1, 1),
padding=xla_client.PaddingType.VALID)
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
window_dimensions=(2, 1),
window_strides=(1, 1),
padding=xla_client.PaddingType.VALID)
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
def testReduceWindowSameUnitStridesF32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
window_dimensions=(2, 1), window_strides=(1, 1),
padding=xla_client.PaddingType.SAME)
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
window_dimensions=(2, 1),
window_strides=(1, 1),
padding=xla_client.PaddingType.SAME)
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
def testReduceWindowValidGeneralStridesF32(self):
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
window_dimensions=(2, 1), window_strides=(1, 2),
padding=xla_client.PaddingType.VALID)
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF32Scalar(0),
computation_to_apply=self._CreateBinaryAddF32Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID)
self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
def testReduceWindowValidUnitStridesF64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
window_dimensions=(2, 1), window_strides=(1, 1),
padding=xla_client.PaddingType.VALID)
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
window_dimensions=(2, 1),
window_strides=(1, 1),
padding=xla_client.PaddingType.VALID)
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
def testReduceWindowSameUnitStridesF64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
window_dimensions=(2, 1), window_strides=(1, 1),
padding=xla_client.PaddingType.SAME)
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
window_dimensions=(2, 1),
window_strides=(1, 1),
padding=xla_client.PaddingType.SAME)
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
def testReduceWindowValidGeneralStridesF64(self):
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
c = self._NewComputation()
c.ReduceWindow(operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
window_dimensions=(2, 1), window_strides=(1, 2),
padding=xla_client.PaddingType.VALID)
c.ReduceWindow(
operand=c.Constant(input_array),
init_value=c.ConstantF64Scalar(0),
computation_to_apply=self._CreateBinaryAddF64Computation(),
window_dimensions=(2, 1),
window_strides=(1, 2),
padding=xla_client.PaddingType.VALID)
self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
def testWhileF32(self):
@ -1636,7 +1673,7 @@ class EmbeddedComputationsTest(ComputationTest):
to_infeed = NumpyArrayS32([1, 2, 3, 4])
c = self._NewComputation()
c.Infeed(xla_client.Shape.from_pyval(to_infeed[0]))
compiled_c = c.Build().CompileWithExampleArguments()
compiled_c = c.Build().Compile()
for item in to_infeed:
xla_client.transfer_to_infeed(item)
@ -1650,7 +1687,7 @@ class EmbeddedComputationsTest(ComputationTest):
x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0]))
c.Outfeed(x)
compiled_c = c.Build().CompileWithExampleArguments()
compiled_c = c.Build().Compile()
for want in to_round_trip:
execution = threading.Thread(target=compiled_c.Execute)
@ -1673,8 +1710,9 @@ class EmbeddedComputationsTest(ComputationTest):
dnums.index_vector_dim = 1
c = self._NewComputation()
c.Scatter(c.Constant(a), c.Constant(scatter_indices), c.Constant(updates),
self._CreateBinaryAddS32Computation(), dnums)
c.Scatter(
c.Constant(a), c.Constant(scatter_indices), c.Constant(updates),
self._CreateBinaryAddS32Computation(), dnums)
expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32)
self._ExecuteAndCompareClose(c, expected=expected)
@ -1685,15 +1723,34 @@ class ErrorTest(ComputationTest):
self.f32_scalar_2 = NumpyArrayF32(2.0)
self.s32_scalar_2 = NumpyArrayS32(2)
def testCompileWithWrongElementTypeInLayout(self):
c = self._NewComputation()
c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
c.ParameterFromNumpy(self.s32_scalar_2)
c.ClearOpMetadata()
options = xla_client.CompileOptions()
options.argument_layouts = [xla_client.Shape.array_shape(np.float32, [])]
def TestFun():
return c.Build().Compile(compile_options=options)
self.assertRaisesRegexp(
RuntimeError, r".*Invalid argument shape.*"
r"expected s32\[\], got f32\[\].*", TestFun)
def testInvokeWithWrongElementType(self):
c = self._NewComputation()
c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
c.ParameterFromNumpy(self.s32_scalar_2)
c.ClearOpMetadata()
def TestFun():
return c.Build().Compile().ExecuteWithPythonValues([self.f32_scalar_2])
self.assertRaisesRegexp(
RuntimeError, r"Invalid argument shape.*xla_client_test.py.*"
r"expected s32\[\], got f32\[\]",
lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2]))
RuntimeError, r"Invalid argument: Argument does not match.*"
r"want s32\[\], got f32\[\].*", TestFun)
class ComputationRootTest(ComputationTest):
@ -1706,7 +1763,7 @@ class ComputationRootTest(ComputationTest):
extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable
arg = NumpyArrayF32(1.0)
compiled_c = c.Build(result).CompileWithExampleArguments([arg])
compiled_c = c.Build(result).Compile()
ans = compiled_c.ExecuteWithPythonValues([arg])
np.testing.assert_allclose(ans, 4.14)

View File

@ -77,14 +77,13 @@ class XrtBackend(xla_client.Backend):
def destructure_tuple(self, c_buffer):
return c_buffer.DestructureTuple()
def compile(self, computation, arg_shapes, result_shape, compile_options):
del arg_shapes
del result_shape
def compile(self, computation, compile_options):
# pylint: disable=protected-access
program_shape = xla_client._wrap_program_shape(
computation.GetProgramShape())
# pylint: enable=protected-access
proto = computation.GetSerializedProto()
# TODO(phawkins): use the layouts in compile_options.
arg_shapes = [
_make_xla_shape(shape.with_major_to_minor_layout_if_absent())
for shape in program_shape.parameter_shapes

View File

@ -2903,6 +2903,7 @@ cc_library(
"hlo_pass_pipeline.h",
],
deps = [
":compilation_stats",
":dump",
":hlo",
":hlo_graph_dumper",
@ -2913,7 +2914,6 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
@ -3914,6 +3914,19 @@ cc_library(
],
)
cc_library(
name = "compilation_stats",
srcs = ["compilation_stats.cc"],
hdrs = ["compilation_stats.h"],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
],
)
cc_library(
name = "dynamic_index_splitter",
srcs = ["dynamic_index_splitter.cc"],

View File

@ -66,38 +66,16 @@ const absl::optional<std::set<int>>& BackendOptions::allowed_devices() const {
return allowed_devices_;
}
namespace {
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
public:
explicit EigenThreadPoolWrapper(tensorflow::thread::ThreadPool* pool)
: pool_(pool) {}
~EigenThreadPoolWrapper() override {}
void Schedule(std::function<void()> fn) override {
pool_->Schedule(std::move(fn));
}
int NumThreads() const override { return pool_->NumThreads(); }
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
private:
tensorflow::thread::ThreadPool* pool_ = nullptr;
};
} // namespace
// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct Backend::IntraOpThreadPool {
explicit IntraOpThreadPool(const int num_threads)
: pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(),
"XLAEigen", num_threads)),
wrapper(new EigenThreadPoolWrapper(pool.get())),
device(new Eigen::ThreadPoolDevice(wrapper.get(),
wrapper->NumThreads())) {}
device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
pool->NumThreads())) {}
std::unique_ptr<tensorflow::thread::ThreadPool> pool;
std::unique_ptr<EigenThreadPoolWrapper> wrapper;
std::unique_ptr<Eigen::ThreadPoolDevice> device;
};

View File

@ -23,6 +23,24 @@ namespace xla {
StatusOr<bool>
BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
HloInstruction* batch_dot) {
// This pass assumes the lhs and rhs batch dimensions are equal and strictly
// ascending.
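// For example (illustrative note), batch dimensions {0, 1} on both operands
// qualify, while {1, 0} or {0, 2} do not.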
const auto& is_iota = [](absl::Span<const int64> dims) {
for (int64 i = 0; i < dims.size(); ++i) {
if (dims[i] != i) {
return false;
}
}
return true;
};
if (!absl::c_equal(
batch_dot->dot_dimension_numbers().lhs_batch_dimensions(),
batch_dot->dot_dimension_numbers().rhs_batch_dimensions()) ||
!is_iota(AsInt64Slice(
batch_dot->dot_dimension_numbers().lhs_batch_dimensions()))) {
return false;
}
const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers();
HloInstruction *lhs = batch_dot->mutable_operand(0),
*rhs = batch_dot->mutable_operand(1);

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@ -30,22 +31,15 @@ namespace xla {
class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
public:
explicit BFloat16NormalizationVisitor(
HloComputation* computation, const BFloat16Support* bfloat16_support,
const BFloat16Support* bfloat16_support,
BFloat16Normalization* bfloat16_normalization)
: computation_(computation),
: computation_(nullptr),
bfloat16_support_(bfloat16_support),
bfloat16_normalization_(bfloat16_normalization) {}
bool changed() const { return changed_; }
Status DefaultAction(HloInstruction* hlo) override;
static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support,
BFloat16Normalization* bfloat16_normalization) {
BFloat16NormalizationVisitor visitor(computation, bfloat16_support,
bfloat16_normalization);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
Status Preprocess(HloInstruction* hlo) override;
private:
// Checks if the HLO uses BF16 in an unsupported way, and if so, inserts
@ -408,18 +402,21 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
return HandleInstruction(hlo);
}
Status BFloat16NormalizationVisitor::Preprocess(HloInstruction* hlo) {
computation_ = hlo->parent();
return Status::OK();
}
StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
XLA_VLOG_LINES(
2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
bool changed = false;
BFloat16NormalizationVisitor visitor(bfloat16_support_, this);
for (auto* comp : module->MakeComputationPostOrder()) {
if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_, this)) {
changed = true;
}
TF_RETURN_IF_ERROR(comp->Accept(&visitor));
}
XLA_VLOG_LINES(2,
"BFloat16Normalization::Run(), after:\n" + module->ToString());
return changed;
return visitor.changed();
}
} // namespace xla

View File

@ -105,6 +105,8 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
return operand_index == 0;
case HloOpcode::kDynamicUpdateSlice:
return operand_index == 0 || operand_index == 1;
case HloOpcode::kGather:
return operand_index == 0;
case HloOpcode::kSelect:
case HloOpcode::kTupleSelect:
return operand_index == 1 || operand_index == 2;

View File

@ -277,8 +277,8 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
// Constructor for CallGraph is private so absl::make_unique can't be used.
auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));
VLOG(2) << "Building call graph for:";
XLA_VLOG_LINES(2, module->ToString());
VLOG(3) << "Building call graph for:";
XLA_VLOG_LINES(3, module->ToString());
// Construct nodes of the call graph and populate the callsites.
for (HloComputation* computation : module->computations()) {
@ -309,7 +309,7 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
call_graph->SetCallContexts();
call_graph->SetNodeDepths();
XLA_VLOG_LINES(1, call_graph->ToString());
XLA_VLOG_LINES(2, call_graph->ToString());
return call_graph;
}

View File

@ -0,0 +1,132 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/compilation_stats.h"
#include <iostream>
#include <memory>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"
namespace xla {
class NoopStats : public CompilationStats {
public:
NoopStats() = default;
void StartPass(absl::string_view pass_name) override {}
void EndPass(absl::string_view pass_name) override {}
void CompilationReport() override {}
};
class Stats : public CompilationStats {
public:
Stats() = default;
void StartPass(absl::string_view pass_name) override;
void EndPass(absl::string_view pass_name) override;
void CompilationReport() override;
private:
struct PassInfo {
PassInfo(absl::string_view name, double duration)
: name(name), duration_ms(duration) {}
absl::string_view name;
int num_runs = 1;
double duration_ms;
};
// Info about the passes that have been run so far.
std::vector<PassInfo> passes_;
// Used to avoid nested calls to StartPass.
bool pass_running_ = false;
absl::string_view current_pass_;
// The start time of the currently running pass.
uint64 start_micros_;
};
/* static */
std::unique_ptr<CompilationStats> CompilationStats::MakeNoopStats() {
return absl::make_unique<NoopStats>();
}
/* static */
std::unique_ptr<CompilationStats> CompilationStats::MakeStats() {
return absl::make_unique<Stats>();
}
void Stats::StartPass(absl::string_view pass_name) {
CHECK(!pass_running_) << "Can't start " << pass_name << " while running "
<< current_pass_;
pass_running_ = true;
current_pass_ = pass_name;
start_micros_ = tensorflow::Env::Default()->NowMicros();
}
void Stats::EndPass(absl::string_view pass_name) {
CHECK(pass_running_);
CHECK_EQ(current_pass_, pass_name);
pass_running_ = false;
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
double duration_ms = (end_micros - start_micros_) / 1000.0;
passes_.push_back(PassInfo(current_pass_, duration_ms));
}
void Stats::CompilationReport() {
CHECK(!pass_running_) << "EndPass never called for " << current_pass_;
absl::flat_hash_map<absl::string_view, PassInfo> summary;
double total_duration = 0;
for (auto& pass_run : passes_) {
auto pass_name = pass_run.name;
total_duration += pass_run.duration_ms;
auto it = summary.find(pass_name);
if (it == summary.end()) {
summary.insert(std::make_pair(pass_name, pass_run));
} else {
++summary.at(pass_name).num_runs;
summary.at(pass_name).duration_ms += pass_run.duration_ms;
}
}
std::vector<PassInfo> sorted_summary;
sorted_summary.reserve(summary.size());
for (auto& it : summary) {
sorted_summary.push_back(it.second);
}
absl::c_sort(sorted_summary, [](const PassInfo& a, const PassInfo& b) {
// Sort passes that take the longest first, break ties using pass names.
return std::make_pair(b.duration_ms, a.name) <
std::make_pair(a.duration_ms, b.name);
});
LOG(INFO) << "Total runtime (ms) of HLO passes: " << total_duration;
LOG(INFO) << "Pass name, num runs, time (ms)";
for (auto& pass_info : sorted_summary) {
LOG(INFO) << pass_info.name << ", " << pass_info.num_runs << ", "
<< pass_info.duration_ms;
}
}
} // namespace xla

View File

@ -0,0 +1,48 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_STATS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_STATS_H_
#include <memory>
#include <string>
#include "absl/strings/str_format.h"
namespace xla {
// This class is used to collect information about HLO passes and print some
// statistics at the end of compilation. From HloPassPipeline, we call StartPass
// before the execution of a pass, and EndPass after. Currently, we only collect
// timing information and how many times each pass was run. In the future, we
// can add more things, such as the size of the HLO graph after each pass.
class CompilationStats {
public:
virtual ~CompilationStats() = default;
static std::unique_ptr<CompilationStats> MakeNoopStats();
static std::unique_ptr<CompilationStats> MakeStats();
virtual void StartPass(absl::string_view pass_name) = 0;
virtual void EndPass(absl::string_view pass_name) = 0;
virtual void CompilationReport() = 0;
};
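// Example usage (illustrative sketch only; the actual call sites live in
// HloPassPipeline and may differ):
//
//   std::unique_ptr<CompilationStats> stats = CompilationStats::MakeStats();
//   stats->StartPass("some-pass");   // "some-pass" is a placeholder name.
//   // ... run the pass ...
//   stats->EndPass("some-pass");
//   stats->CompilationReport();      // Logs per-pass run counts and durations.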
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_STATS_H_

View File

@ -209,7 +209,6 @@ void CompilerFunctor::AddOptimizationPasses(
builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
}
builder.DisableUnitAtATime = false;
builder.DisableUnrollLoops = opt_level == 0;
builder.LoopVectorize = opt_level > 0 && size_level == 0;
builder.SLPVectorize = opt_level > 1 && size_level == 0;

View File

@ -275,6 +275,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<CallInliner>();
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
// After canonicalization, there may be more batch dots that can be
// simplified.
pipeline.AddPass<BatchDotSimplification>();
auto cost_model = [](HloInstruction* conv) {
// We need a cost model for CPUs. Currently, do nothing.
return false;

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@ -101,8 +100,7 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
} else {
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
2);
tensorflow::EigenThreadPoolWrapper tp(&pool);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
ExecutableRunOptions run_options;
run_options.set_intra_op_thread_pool(&device);

View File

@ -291,6 +291,7 @@ class DfsHloVisitorBase {
// This call is purely a performance hint and can be omitted without
// affecting correctness.
void ReserveVisitStates(int num) { visit_state_.reserve(num); }
size_t VisitStateSize() const { return visit_state_.size(); }
// Useful when we want to visit the same computation more than once with the
// same visitor.

View File

@ -394,6 +394,7 @@ cc_library(
":outfeed_manager",
":partition_assignment",
":stream_assignment",
":stream_executor_util",
":thunk",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
@ -418,6 +419,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:gpu_utils",
"//tensorflow/core/platform/default/build_config:cublas_plugin",
"//tensorflow/core/platform/default/build_config:cudnn_plugin",
"//tensorflow/core/platform/default/build_config:cufft_plugin",
@ -467,7 +469,9 @@ cc_library(
":gpu_executable",
":ir_emission_utils",
":redzone_allocator",
":stream_executor_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
@ -478,7 +482,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:logger",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/util/proto:proto_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",

View File

@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
#include "google/protobuf/any.pb.h"
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
@ -26,7 +28,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logger.h"
@ -76,28 +81,6 @@ string NumBytesToString(int64 bytes) {
bytes, "B)");
}
// Acquires a process-global lock on the device pointed to by the given
// StreamExecutor.
//
// This is used to prevent other XLA instances from trying to autotune on this
// device while we're using it.
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
// se::Platform*s are global singletons guaranteed to live forever.
static auto* mutexes =
new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
tensorflow::mutex>();
tensorflow::mutex_lock global_lock(mu);
auto it = mutexes
->emplace(std::piecewise_construct,
std::make_tuple(stream_exec->platform(),
stream_exec->device_ordinal()),
std::make_tuple())
.first;
return tensorflow::mutex_lock{it->second};
}
tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
tensorflow::CudnnVersion cudnn_version;
if (auto* dnn = stream_executor->AsDnn()) {
@ -179,33 +162,88 @@ bool CheckRedzones(const RedzoneAllocator& allocator, se::Stream* stream,
return false;
}
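// Cache key for conv autotuning results. As built by
// AutotuneCacheKeyfromInstruction below, it captures the StreamExecutor, the
// serialized backend config, the custom-call target, the result shape, the
// operand shapes, the serialized window and dimension numbers, and the
// feature group count.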
using ConvCacheKey =
std::tuple<se::StreamExecutor*, std::string, std::string, Shape,
std::vector<Shape>, std::string, std::string, int64>;
struct ConvCacheStats {
int64 cache_hits = 0;
int64 cache_misses = 0;
void LogStats() {
VLOG(1) << "Cache hits: " << cache_hits;
VLOG(1) << "Cache misses: " << cache_misses;
}
};
StatusOr<ConvCacheKey> AutotuneCacheKeyfromInstruction(
const HloCustomCallInstruction* conv, se::StreamExecutor* se) {
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
conv->backend_config<CudnnConvBackendConfig>());
std::vector<Shape> operand_shapes;
absl::c_transform(conv->operands(), std::back_inserter(operand_shapes),
[&](const HloInstruction* op) { return op->shape(); });
return std::make_tuple(
se, backend_config.SerializeAsString(), conv->custom_call_target(),
conv->shape(), std::move(operand_shapes),
conv->window().SerializeAsString(),
conv->convolution_dimension_numbers().SerializeAsString(),
conv->feature_group_count());
}
tensorflow::mutex autotune_cache_lock(tensorflow::LINKER_INITIALIZED);
auto& autotune_cache GUARDED_BY(autotune_cache_lock) =
*new absl::flat_hash_map<ConvCacheKey, AutotuneResult>();
auto& autotune_cache_stats GUARDED_BY(autotune_cache_lock) =
*new ConvCacheStats();
} // anonymous namespace
// We could have caching here so that we don't redo this work for two identical
// convolutions. Unfortunately our cache key would have to be a tuple
// containing the protos passed to this function, and we have no utility for
// hashing protos. We could write our own hash functions, but they'd silently
// break if we ever added a field to one of the protos. Perhaps we could hack
// using the binary-encoded proto as the hash key, on the assumption that two
// protos being binary-equal is a sufficient, if not necessary, condition for
// proper equality. But that would still leave us open to having unnecessary
// cache misses and doing extra work. Overall, caching doesn't seem worth the
// trouble, but we may want to revisit this if we ever find a model where
// caching would speed up compilation a lot.
StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
const HloCustomCallInstruction* instr) {
XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
"CudnnConvAlgorithmPicker::PickBestAlgorithm for ", instr->ToString()));
const Shape& result_shape = instr->shape().tuple_shapes(0);
// Don't run this function concurrently on the same GPU.
//
// This is a bit of a hack and doesn't protect us against arbitrary concurrent
// use of a GPU, but it's sufficient to let us compile two HLO modules
// concurrently and then run them sequentially.
//
// Putting the lock in here rather than in PickBestAlgorithmNoCache lets us
// avoid ever doing duplicate work. If we have a cache miss, only one thread
// will run PickBestAlgorithmImpl for a particular device.
tensorflow::mutex_lock lock = LockGpu(stream_exec_);
// We cache the autotuning results to avoid doing the duplicate work,
// which can greatly improve both stability (deterministic numeric results
// within a process for a given input) and performance (2x speedup on some
// models).
TF_ASSIGN_OR_RETURN(ConvCacheKey key,
AutotuneCacheKeyfromInstruction(instr, stream_exec_));
{
tensorflow::mutex_lock lock(autotune_cache_lock);
auto it = autotune_cache.find(key);
if (it != autotune_cache.end()) {
autotune_cache_stats.cache_hits++;
return it->second;
}
autotune_cache_stats.cache_misses++;
}
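// Cache miss: fall through to the full autotuning below; a successful result
// is inserted into the cache so identical convolutions later in the
// compilation reuse it.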
StatusOr<AutotuneResult> result_or = PickBestAlgorithmNoCache(instr);
if (result_or.ok()) {
tensorflow::mutex_lock lock(autotune_cache_lock);
CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
}
return result_or;
}
StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
const HloCustomCallInstruction* instr) {
XLA_SCOPED_LOGGING_TIMER(
absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithmImpl for ",
instr->ToString()));
const Shape& result_shape = instr->shape().tuple_shapes(0);
// Make sure any previous activity on this executor is done. We don't want to
// interfere with programs that are still running on the GPU.
if (!stream_exec_->SynchronizeAllActivity()) {
@ -543,6 +581,12 @@ StatusOr<bool> CudnnConvAlgorithmPicker::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
changed |= result;
}
{
tensorflow::mutex_lock lock(autotune_cache_lock);
autotune_cache_stats.LogStats();
}
return changed;
}

View File

@ -52,6 +52,8 @@ class CudnnConvAlgorithmPicker : public HloModulePass {
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithm(
const HloCustomCallInstruction* instr);
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCache(
const HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null

View File

@ -364,7 +364,6 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
params.output_buf = operand_buffers[1];
break;
case CudnnConvKind::kForwardActivation: {
params.kind = CudnnConvKind::kForwardActivation;
params.input_shape = &lhs_shape;
params.filter_shape = &rhs_shape;
params.output_shape = &conv_result_shape;

View File

@ -201,8 +201,8 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
HloModule m
f1_computation {
f1_p0 = f32[10]{0} parameter(0)
ROOT f1_root = f32[10]{0} add(f1_p0, f1_p0)
f1_p0 = f32[32]{0} parameter(0)
ROOT f1_root = f32[32]{0} add(f1_p0, f1_p0)
}
add_computation {
@ -212,16 +212,16 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
}
f2_computation {
f2_p0 = f32[10]{0} parameter(0)
f2_mul = f32[10]{0} multiply(f2_p0, f2_p0)
f2_p0 = f32[32]{0} parameter(0)
f2_mul = f32[32]{0} multiply(f2_p0, f2_p0)
f2_zero = f32[] constant(0)
ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0},
to_apply=add_computation
}
ENTRY entry {
p0 = f32[10]{0} parameter(0)
f1 = f32[10]{0} fusion(p0), kind=kLoop, calls=f1_computation
p0 = f32[32]{0} parameter(0)
f1 = f32[32]{0} fusion(p0), kind=kLoop, calls=f1_computation
ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
})")
.ValueOrDie();

View File

@ -17,68 +17,42 @@ limitations under the License.
#include <functional>
#include "absl/strings/str_cat.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
namespace xla {
namespace gpu {
namespace {
using MatrixDescriptor = gemm_thunk_internal::MatrixDescriptor;
// Performs a gemm call without an explicit algorithm on lhs_matrix and
// rhs_matrix, and stores the result to output_matrix.
template <typename Element>
bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, double alpha, double beta,
se::Stream* stream) {
DCHECK(!output_matrix.transpose);
// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
struct MatrixDescriptor {
se::DeviceMemoryBase data;
bool transpose; // Whether this matrix needs to be transposed.
int64 num_rows;
int64 num_cols;
};
const int64 batch_size = lhs_matrix.batch_size;
CHECK_EQ(batch_size, rhs_matrix.batch_size);
CHECK_EQ(batch_size, output_matrix.batch_size);
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
using GemmCacheKey =
std::tuple<PrimitiveType, bool, int64, int64, bool, int64, int64, double,
double, se::blas::ComputationType, se::StreamExecutor*>;
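// Illustrative note: as assembled in DoGemmAutotune below, the key consists of
// the element type, each operand's transpose flag and dimensions, alpha, beta,
// the blas ComputationType, and the target StreamExecutor.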
auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
: se::blas::Transpose::kNoTranspose;
auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
: se::blas::Transpose::kNoTranspose;
auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
tensorflow::mutex autotune_cache_mu(tensorflow::LINKER_INITIALIZED);
auto& autotune_cache GUARDED_BY(autotune_cache_mu) = *new absl::flat_hash_map<
GemmCacheKey, absl::optional<se::blas::AlgorithmType>>();
int64 cache_hits GUARDED_BY(autotune_cache_mu) = 0;
int64 cache_misses GUARDED_BY(autotune_cache_mu) = 0;
if (batch_size == 1) {
return stream
->ThenBlasGemm(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
&output_data, /*leading dim of output=*/output_matrix.num_rows)
.ok();
}
int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
return stream
->ThenBlasGemmStridedBatched(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k,
/*alpha=*/alpha, lhs_data,
/*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
/*beta=*/beta, &output_data,
/*leading dim of output=*/output_matrix.num_rows, output_stride,
batch_size)
.ok();
}
// Like DoGemm, but takes an explicit computation type and algorithm.
// Performs a gemm call on lhs_matrix and rhs_matrix, and stores the result
// to output_matrix.
//
// computation_type specifies the type of intermediate values generated during
// the matmul (e.g. your input/output matrices could be f16s but you could do
// computations with f32s). algorithm is an opaque identifier which functions
@ -93,19 +67,16 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
// the Stream was valid to begin with); check the is_valid property of the
// ProfileResult to see whether the call actually succeeded.
template <typename Element>
bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
bool DoGemmWithAlgorithm(int64 batch_size, MatrixDescriptor lhs_matrix,
MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, double alpha,
double beta,
se::blas::ComputationType computation_type,
se::blas::AlgorithmType algorithm, se::Stream* stream,
se::Stream* stream,
absl::optional<se::blas::AlgorithmType> algorithm,
se::blas::ProfileResult* output_profile_result) {
DCHECK(!output_matrix.transpose);
CHECK_EQ(1, lhs_matrix.batch_size);
CHECK_EQ(1, rhs_matrix.batch_size);
CHECK_EQ(1, output_matrix.batch_size);
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
@ -116,17 +87,45 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
: se::blas::Transpose::kNoTranspose;
auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
return stream
->ThenBlasGemmWithAlgorithm(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k,
/*alpha=*/static_cast<Element>(alpha), lhs_data,
/*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows,
/*beta=*/static_cast<Element>(beta), &output_data,
/*leading dim of output=*/output_matrix.num_rows, computation_type,
algorithm, output_profile_result)
.ok();
if (algorithm) {
// Autotuning is disabled for batch_size != 1.
CHECK_EQ(1, batch_size);
return stream
->ThenBlasGemmWithAlgorithm(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k,
/*alpha=*/static_cast<Element>(alpha), lhs_data,
/*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows,
/*beta=*/static_cast<Element>(beta), &output_data,
/*leading dim of output=*/output_matrix.num_rows, computation_type,
*algorithm, output_profile_result)
.ok();
} else if (batch_size == 1) {
return stream
->ThenBlasGemm(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
&output_data, /*leading dim of output=*/output_matrix.num_rows)
.ok();
} else {
int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
return stream
->ThenBlasGemmStridedBatched(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k,
/*alpha=*/alpha, lhs_data,
/*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
/*beta=*/beta, &output_data,
/*leading dim of output=*/output_matrix.num_rows, output_stride,
batch_size)
.ok();
}
}
// Experimentally tries to pick the best algorithm for the given gemm.
@ -136,10 +135,43 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
// all.
template <typename Element>
StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, double alpha, double beta,
se::blas::ComputationType computation_type, se::Stream* stream) {
absl::optional<se::blas::AlgorithmType> DoUncachedGemmAutotune(
PrimitiveType type, se::blas::ComputationType computation_type,
int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, se::Stream* stream,
const Shape output_shape, double alpha, double beta,
absl::string_view instr_descr) {
if (!stream->BlockHostUntilDone().ok()) {
VLOG(2) << "Failed to synchronize GPU for autotuning";
return absl::nullopt;
}
VLOG(3) << "Starting autotune of GemmThunk " << instr_descr;
// If the output buffer already contains a bias then autotune into a
// scratch buffer. This avoids overwriting the bias buffer. The scratch
// buffer may contain arbitrary garbage values.
se::DeviceMemoryBase scratch_data = output_matrix.data;
absl::optional<se::ScopedDeviceMemory<char>> allocated_memory;
if (beta != 0.0) {
se::DeviceMemory<char> out = stream->parent()->AllocateArray<char>(
ShapeUtil::ByteSizeOf(output_shape));
if (out.is_null()) {
VLOG(1) << "Allocation failed, using generic algorthm";
return absl::nullopt;
}
// Destructor ensures deallocation at the end of the scope.
allocated_memory.emplace(stream->parent(), out);
scratch_data = out;
}
const MatrixDescriptor scratch_descriptor{scratch_data,
/*needs_transpose=*/false,
output_matrix.num_rows,
output_matrix.num_cols};
std::vector<se::blas::AlgorithmType> algorithms;
CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));
@ -150,9 +182,9 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
// for all algorithms if we're targeting < sm_50. But because we pass a
// non-null ProfileResult, DoGemmWithAlgorithm should always return true,
// and the actual success-ness is returned in ProfileResult::is_valid.
CHECK(DoGemmWithAlgorithm<Element>(lhs_matrix, rhs_matrix, output_matrix,
alpha, beta, computation_type, algorithm,
stream, &profile_result));
CHECK(DoGemmWithAlgorithm<Element>(
/*batch_size=*/1, lhs_matrix, rhs_matrix, scratch_descriptor, alpha,
beta, computation_type, stream, algorithm, &profile_result));
if (profile_result.is_valid()) {
VLOG(3) << "cublas gemm algorithm " << algorithm << " took "
@ -167,65 +199,14 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
}
if (best_result.is_valid()) {
VLOG(2) << "Autotune on GemmThunk " << instr_descr
<< " successful; best algorithm is " << best_result.algorithm();
return best_result.algorithm();
}
return InternalError(
"Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms "
"ran successfully",
stream, algorithms.size());
}
// Helper functions to go from a PrimitiveType to a templated version of
// DoGemm/DoGemmWithAlgorithm/DoGemmAutotune.
auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
switch (type) {
case F16:
return &DoGemm<Eigen::half>;
case F32:
return &DoGemm<float>;
case F64:
return &DoGemm<double>;
case C64:
return &DoGemm<std::complex<float>>;
case C128:
return &DoGemm<std::complex<double>>;
default:
LOG(FATAL) << "Unsupported type.";
}
}
auto GetGemmWithAlgorithmFn(PrimitiveType type)
-> decltype(&DoGemmWithAlgorithm<float>) {
switch (type) {
case F16:
return &DoGemmWithAlgorithm<Eigen::half>;
case F32:
return &DoGemmWithAlgorithm<float>;
case F64:
return &DoGemmWithAlgorithm<double>;
case C64:
return &DoGemmWithAlgorithm<std::complex<float>>;
case C128:
return &DoGemmWithAlgorithm<std::complex<double>>;
default:
LOG(FATAL) << "Unsupported type.";
}
}
auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
switch (type) {
case F16:
return &DoGemmAutotune<Eigen::half>;
case F32:
return &DoGemmAutotune<float>;
case F64:
return &DoGemmAutotune<double>;
case C64:
return &DoGemmAutotune<std::complex<float>>;
case C128:
return &DoGemmAutotune<std::complex<double>>;
default:
LOG(FATAL) << "Unsupported type.";
}
VLOG(1) << "Unable to autotune cuBLAS gemm on stream " << stream
<< " none of the " << algorithms.size() << " ran successfully";
return absl::nullopt;
}
// Converts from an XLA PrimitiveType to a blas::ComputationType, which is used
@ -250,6 +231,48 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
}
}
template <typename Element>
absl::optional<se::blas::AlgorithmType> DoGemmAutotune(
int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, se::Stream* stream,
const Shape output_shape, double alpha, double beta,
absl::string_view instr_descr) {
PrimitiveType type = output_shape.element_type();
se::blas::ComputationType computation_type = GetBlasComputationType(type);
tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent());
GemmCacheKey key = std::make_tuple(
type, lhs_matrix.transpose, lhs_matrix.num_rows, lhs_matrix.num_cols,
rhs_matrix.transpose, rhs_matrix.num_rows, rhs_matrix.num_cols, alpha,
beta, computation_type, stream->parent());
tensorflow::mutex_lock cache_lock(autotune_cache_mu);
auto it = autotune_cache.find(key);
int64 autotuning_requests = cache_hits + cache_misses;
if (autotuning_requests && autotuning_requests % 10 == 0) {
VLOG(2) << "Autotuning cache hits/(hits + misses): " << cache_hits << "/"
<< autotuning_requests;
}
if (it != autotune_cache.end()) {
cache_hits++;
VLOG(4)
<< "Autotuning cache hit, using algorithm (-1 stands for 'generic'): "
<< it->second.value_or(-1);
return it->second;
}
cache_misses++;
VLOG(4) << "Autotuning cache miss";
auto result = DoUncachedGemmAutotune<Element>(
type, computation_type, batch_size, lhs_matrix, rhs_matrix, output_matrix,
stream, output_shape, alpha, beta, instr_descr);
CHECK(autotune_cache.emplace(key, result).second);
return result;
}
DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
if (hlo_instruction.opcode() == HloOpcode::kDot) {
return hlo_instruction.dot_dimension_numbers();
@ -270,6 +293,138 @@ DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
return dot->dot_dimension_numbers();
}
template <typename Element>
Status ExecuteOnStreamParameterized(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler, const BufferAllocation::Slice lhs_buffer,
const BufferAllocation::Slice rhs_buffer,
const BufferAllocation::Slice output_buffer, const Shape lhs_shape,
const Shape rhs_shape, const Shape output_shape,
bool implements_whole_instruction, const HloInstruction* hlo_instruction,
double alpha, double beta, bool xla_gpu_disable_autotune) {
VLOG(2) << "Executing a GemmThunk";
se::DeviceMemoryBase lhs_data =
buffer_allocations.GetDeviceAddress(lhs_buffer);
se::DeviceMemoryBase rhs_data =
buffer_allocations.GetDeviceAddress(rhs_buffer);
se::DeviceMemoryBase output_data =
buffer_allocations.GetDeviceAddress(output_buffer);
DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction);
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
dim_nums.rhs_batch_dimensions_size());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape.rank());
int64 row_dim = dim_nums.lhs_batch_dimensions_size();
int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
int64 batch_size = std::accumulate(output_shape.dimensions().begin(),
output_shape.dimensions().end() - 2, 1,
std::multiplies<int64>());
// Check that the batch dims don't cover the last two dims.
for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
CHECK_NE(row_dim, batch_dim);
CHECK_NE(col_dim, batch_dim);
}
// Verify that the non-batch dimensions are minor-most. This is required for
// efficient access.
for (const auto* shape : {&lhs_shape, &rhs_shape, &output_shape}) {
CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
}
// BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
// matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
// their layout. Therefore, we should treat dimension 0 as row and dimension 1
// as column when mapping a matrix Dot to BLAS gemm.
int64 output_num_rows = output_shape.dimensions(row_dim);
int64 output_num_cols = output_shape.dimensions(col_dim);
// BLAS gemm expects the inputs and the output are in column-major order.
// Therefore, we need to convert dot between row-major matrices to that
// between column-major matrices. The key insight for the conversion is that,
// in linear storage, matrix M in column-major order is identical to the
// transpose of M in row-major order. In other words,
//
// column-major(M) = row-major(M^T).
//
// Leveraging this insight, we can perform dot between row-major matrices as
// follows.
//
// row-major(C)
// = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
// = gemm(column-major(B^T), column-major(A^T))
// = gemm(row-major(B), row-major(A))
//
// Although we do not modify the content of A and B in linear memory, we
// should use the dimensions of B^T and A^T when calling gemm. For example,
// the leading dimension of the LHS matrix of gemm is the number of rows in
// B^T and thus the number of columns in B.
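// As a concrete illustration (not part of the original code): for a row-major
// dot C[2,4] = A[2,3] x B[3,4], the call below treats B as the gemm LHS with
// num_rows = 4 (the leading dimension, i.e. B's column count) and A as the
// gemm RHS with num_rows = 3, producing a 4x2 column-major result whose bytes
// coincide with row-major C[2,4].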
auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
bool transpose) -> MatrixDescriptor {
bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
LayoutUtil::Minor(output_shape.layout(), row_dim);
return MatrixDescriptor{
data, static_cast<bool>(transpose ^ layout_mismatch),
shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
shape.dimensions(row_dim + static_cast<int64>(!is_row_major))};
};
MatrixDescriptor lhs_matrix = make_descriptor(
lhs_data, lhs_shape, dim_nums.lhs_contracting_dimensions(0) == row_dim);
MatrixDescriptor rhs_matrix = make_descriptor(
rhs_data, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
auto op_profiler = profiler->MakeScopedInstructionProfiler(
implements_whole_instruction ? hlo_instruction : nullptr);
if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {
std::swap(lhs_matrix, rhs_matrix);
std::swap(output_num_cols, output_num_rows);
}
const MatrixDescriptor output_matrix{output_data, /*needs_transpose=*/false,
output_num_rows, output_num_cols};
// Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts
// to autotune this gemm to figure out the best algorithm.
PrimitiveType element_type = output_shape.element_type();
se::blas::ComputationType computation_type =
GetBlasComputationType(element_type);
std::string instr_descr =
hlo_instruction != nullptr ? hlo_instruction->ToString() : "<null>";
// Try finding the best algorithm by autotuning, or use older Gemm API
// if autotuning is disabled or has failed.
absl::optional<se::blas::AlgorithmType> best_algorithm;
if (xla_gpu_disable_autotune) {
VLOG(2) << "Autotuning disabled, using generic algorithm";
} else if (batch_size != 1) {
// TODO(b/112111608): Implement auto tune for batched gemm.
VLOG(2) << "Batch size is non-singular, using generic algorithm";
} else {
// Autotune may fail for various reasons (e.g. when CUDA 8 and GPU
// sm_50 or older are used). In that case the returned best_algorithm
// will be an empty optional.
best_algorithm = DoGemmAutotune<Element>(
batch_size, lhs_matrix, rhs_matrix, output_matrix, stream, output_shape,
alpha, beta, instr_descr);
}
bool launch_ok = DoGemmWithAlgorithm<Element>(
batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha, beta,
computation_type, stream, best_algorithm,
/*output_profile_result=*/nullptr);
if (!launch_ok) {
return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
}
return Status::OK();
}
} // namespace
GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
@ -290,196 +445,30 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
beta_(beta),
implements_whole_instruction_(implements_whole_instruction) {}
absl::optional<se::blas::AlgorithmType> GemmThunk::GetGemmAlgorithm(
int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, se::DeviceMemoryBase output_data,
se::Stream* stream) {
PrimitiveType element_type = output_shape_.element_type();
se::blas::ComputationType computation_type =
GetBlasComputationType(element_type);
// TODO(b/112111608): Implement auto tune for batched gemm.
if (batch_size != 1) {
return absl::nullopt;
}
if (GetModuleConfig().debug_options().xla_gpu_disable_autotune()) {
VLOG(2) << "Auto-tune disabled, using generic algorithm.";
return absl::nullopt;
}
const string& device_name = stream->parent()->GetDeviceDescription().name();
auto autotune_it = autotune_results_.find(device_name);
if (autotune_it == autotune_results_.end()) {
VLOG(3) << "Starting autotune of GemmThunk " << GetThunkName();
// If the output buffer already contains a bias then autotune into a
// scratch buffer. This avoids overwriting the bias buffer. The scratch
// buffer may contain arbitrary garbage values.
se::DeviceMemoryBase scratch_data = output_data;
std::unique_ptr<se::TemporaryDeviceMemory<char>> scratch_mem;
if (beta_ != 0.0) {
auto temp_status = stream->AllocateTemporaryArray<char>(
ShapeUtil::ByteSizeOf(output_shape_));
if (!temp_status.ok()) {
return false;
}
scratch_mem = std::move(temp_status).ValueOrDie();
scratch_data = scratch_mem->device_memory();
}
const MatrixDescriptor scratch_descriptor(
scratch_data, false, output_matrix.num_rows, output_matrix.num_cols,
batch_size);
StatusOr<se::blas::AlgorithmType> best_algorithm = GetGemmAutotuneFn(
element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_, beta_,
computation_type, stream);
autotune_it = autotune_results_.insert({device_name, best_algorithm}).first;
if (autotune_it->second.ok()) {
VLOG(2) << "Autotune on GemmThunk " << GetThunkName()
<< " successful; best algorithm is "
<< best_algorithm.ValueOrDie();
} else {
VLOG(2) << "Autotune on GemmThunk " << GetThunkName()
<< " unsuccessful. Will use generic gemm.";
}
}
if (autotune_it->second.ok()) {
return autotune_it->second.ValueOrDie();
}
return absl::nullopt;
}
Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
HloExecutionProfiler* profiler) {
VLOG(2) << "Executing a GemmThunk";
se::DeviceMemoryBase lhs_data =
buffer_allocations.GetDeviceAddress(lhs_buffer_);
se::DeviceMemoryBase rhs_data =
buffer_allocations.GetDeviceAddress(rhs_buffer_);
se::DeviceMemoryBase output_data =
buffer_allocations.GetDeviceAddress(output_buffer_);
DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
dim_nums.rhs_batch_dimensions_size());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank());
int64 row_dim = dim_nums.lhs_batch_dimensions_size();
int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
int64 batch_size = std::accumulate(output_shape_.dimensions().begin(),
output_shape_.dimensions().end() - 2, 1,
std::multiplies<int64>());
// Check that the batch dims don't cover the last two dims.
for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
CHECK_NE(row_dim, batch_dim);
CHECK_NE(col_dim, batch_dim);
}
// Verify that the non-batch dimensions are minor-most. This is required for
// efficient access.
for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) {
CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
}
// BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
// matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
// their layout. Therefore, we should treat dimension 0 as row and dimension 1
// as column when mapping a matrix Dot to BLAS gemm.
int64 output_num_rows = output_shape_.dimensions(row_dim);
int64 output_num_cols = output_shape_.dimensions(col_dim);
// BLAS gemm expects the inputs and the output are in column-major order.
// Therefore, we need to convert dot between row-major matrices to that
// between column-major matrices. The key insight for the conversion is that,
// in linear storage, matrix M in column-major order is identical to the
// transpose of M in row-major order. In other words,
//
// column-major(M) = row-major(M^T).
//
// Leveraging this insight, we can perform dot between row-major matrices as
// follows.
//
// row-major(C)
// = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
// = gemm(column-major(B^T), column-major(A^T))
// = gemm(row-major(B), row-major(A))
//
// Although we do not modify the content of A and B in linear memory, we
// should use the dimensions of B^T and A^T when calling gemm. For example,
// the leading dimension of the LHS matrix of gemm is the number of rows in
// B^T and thus the number of columns in B.
auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
bool transpose) -> MatrixDescriptor {
bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
LayoutUtil::Minor(output_shape_.layout(), row_dim);
return MatrixDescriptor(
data, transpose ^ layout_mismatch,
shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
shape.dimensions(row_dim + static_cast<int64>(!is_row_major)),
batch_size);
};
const MatrixDescriptor lhs_descriptor = make_descriptor(
lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim);
const MatrixDescriptor rhs_descriptor = make_descriptor(
rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim);
// Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
// autotune this gemm to figure out the best algorithm.
auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, se::Stream* stream) {
PrimitiveType element_type = output_shape_.element_type();
se::blas::ComputationType computation_type =
GetBlasComputationType(element_type);
absl::optional<se::blas::AlgorithmType> best_algorithm = GetGemmAlgorithm(
batch_size, lhs_matrix, rhs_matrix, output_matrix, output_data, stream);
if (best_algorithm.has_value()) {
auto algorithm = best_algorithm.value();
VLOG(2) << "Using algorithm " << algorithm
<< " chosen by autotuning on GemmThunk " << GetThunkName();
return GetGemmWithAlgorithmFn(element_type)(
lhs_matrix, rhs_matrix, output_matrix, alpha_, beta_,
computation_type, algorithm, stream,
/*output_profile_result=*/nullptr);
auto fn = [&]() {
switch (output_shape_.element_type()) {
case F16:
return &ExecuteOnStreamParameterized<Eigen::half>;
case F32:
return &ExecuteOnStreamParameterized<float>;
case F64:
return &ExecuteOnStreamParameterized<double>;
case C64:
return &ExecuteOnStreamParameterized<std::complex<float>>;
case C128:
return &ExecuteOnStreamParameterized<std::complex<double>>;
default:
LOG(FATAL) << "Unsupported type.";
}
}();
// Autotune may fail for various reasons (e.g. when CUDA 8 and GPU
// sm_50 or older are used). Use the older Gemm API in these cases.
return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
alpha_, beta_, stream);
};
auto op_profiler = profiler->MakeScopedInstructionProfiler(
implements_whole_instruction_ ? hlo_instruction() : nullptr);
bool launch_ok;
if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
launch_ok = launch(lhs_descriptor, rhs_descriptor,
MatrixDescriptor(output_data, false, output_num_rows,
output_num_cols, batch_size),
stream);
} else {
launch_ok = launch(rhs_descriptor, lhs_descriptor,
MatrixDescriptor(output_data, false, output_num_cols,
output_num_rows, batch_size),
stream);
}
if (!launch_ok) {
return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
}
return Status::OK();
return fn(buffer_allocations, stream, profiler, lhs_buffer_, rhs_buffer_,
output_buffer_, lhs_shape_, rhs_shape_, output_shape_,
implements_whole_instruction_, hlo_instruction(), alpha_, beta_,
GetModuleConfig().debug_options().xla_gpu_disable_autotune());
}
} // namespace gpu

View File

@ -29,29 +29,6 @@ limitations under the License.
namespace xla {
namespace gpu {
namespace gemm_thunk_internal {
// Internal implementation details for GemmThunk:
// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
struct MatrixDescriptor {
MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose,
int64 matrix_num_rows, int64 matrix_num_cols,
int64 matrix_batch_size)
: data(matrix_data),
transpose(needs_transpose),
num_rows(matrix_num_rows),
num_cols(matrix_num_cols),
batch_size(matrix_batch_size) {}
se::DeviceMemoryBase data;
bool transpose; // Whether this matrix needs to be transposed.
int64 num_rows;
int64 num_cols;
int64 batch_size;
};
} // namespace gemm_thunk_internal
// This class stores everything that StreamExecutor needs to launch a BLAS gemm.
// It is generated by IrEmitter.
//
@ -76,21 +53,7 @@ class GemmThunk : public Thunk {
se::Stream* stream,
HloExecutionProfiler* profiler) override;
bool WillAutotuneKernel(se::Stream* stream) override {
// We will autotune this kernel if we don't already have an autotune result
// for the stream device.
return autotune_results_.find(
stream->parent()->GetDeviceDescription().name()) ==
autotune_results_.end();
}
private:
absl::optional<se::blas::AlgorithmType> GetGemmAlgorithm(
int64 batch_size, gemm_thunk_internal::MatrixDescriptor lhs_matrix,
gemm_thunk_internal::MatrixDescriptor rhs_matrix,
gemm_thunk_internal::MatrixDescriptor output_matrix,
se::DeviceMemoryBase output_data, se::Stream* stream);
const BufferAllocation::Slice lhs_buffer_;
const BufferAllocation::Slice rhs_buffer_;
const BufferAllocation::Slice output_buffer_;
@ -103,20 +66,6 @@ class GemmThunk : public Thunk {
const double beta_;
const bool implements_whole_instruction_;
string GetThunkName() const {
return hlo_instruction() != nullptr ? hlo_instruction()->ToString()
: "<null>";
}
// Maps device names (StreamExecutor::DeviceDescription::name()) to autotune
// results. The map's value is the best algorithm we've found for this thunk
// on this device, or an error if none of the algorithms worked and we should
// use the regular gemm without an algorithm.
//
// TODO(b/112415150): Make this thread safe.
std::unordered_map<string, StatusOr<se::blas::AlgorithmType>>
autotune_results_;
};
} // namespace gpu

View File

@ -132,13 +132,6 @@ Status GpuExecutable::ExecuteThunks(
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
}
// If this thunk is about to autotune then wait for all currently executing
// thunks to finish. This reduces noise and thus the probability of
// choosing a suboptimal algorithm.
if (thunk->WillAutotuneKernel(stream)) {
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
}
VLOG(2) << "Executing the thunk for "
<< thunk->hlo_instruction()->ToString() << " on stream "
<< stream_no;

View File

@ -153,7 +153,8 @@ bool IsInputFusibleScatter(const HloInstruction& instr) {
bool IsInputFusible(const HloInstruction& instr) {
// Input fusion only handles non-elemental reduction and scatter operations.
return IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr);
return instr.IsFusible() &&
(IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr));
}
bool IsLoopFusible(const HloInstruction& instr) {
@ -163,29 +164,42 @@ bool IsLoopFusible(const HloInstruction& instr) {
// compute the address of the GTE at the top of the kernel. Often we know the
// address of the GTE result statically, so we can do this without chasing any
// pointers.
return (instr.IsElementwise() && instr.operand_count() > 0) ||
instr.opcode() == HloOpcode::kBitcast ||
instr.opcode() == HloOpcode::kBroadcast ||
instr.opcode() == HloOpcode::kConcatenate ||
instr.opcode() == HloOpcode::kDynamicSlice ||
instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
(instr.opcode() == HloOpcode::kFusion &&
instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
instr.opcode() == HloOpcode::kGather ||
instr.opcode() == HloOpcode::kIota ||
instr.opcode() == HloOpcode::kPad ||
(instr.opcode() == HloOpcode::kReduce &&
!IsReductionFromOrToContiguousDimensions(instr)) ||
instr.opcode() == HloOpcode::kReduceWindow ||
instr.opcode() == HloOpcode::kReshape ||
instr.opcode() == HloOpcode::kReverse ||
instr.opcode() == HloOpcode::kSlice ||
instr.opcode() == HloOpcode::kTranspose;
return instr.IsFusible() &&
((instr.IsElementwise() && instr.operand_count() > 0) ||
instr.opcode() == HloOpcode::kBitcast ||
instr.opcode() == HloOpcode::kBroadcast ||
instr.opcode() == HloOpcode::kConcatenate ||
instr.opcode() == HloOpcode::kDynamicSlice ||
instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
(instr.opcode() == HloOpcode::kFusion &&
instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
instr.opcode() == HloOpcode::kGather ||
instr.opcode() == HloOpcode::kIota ||
instr.opcode() == HloOpcode::kPad ||
(instr.opcode() == HloOpcode::kReduce &&
!IsReductionFromOrToContiguousDimensions(instr)) ||
instr.opcode() == HloOpcode::kReduceWindow ||
instr.opcode() == HloOpcode::kReshape ||
instr.opcode() == HloOpcode::kReverse ||
instr.opcode() == HloOpcode::kSlice ||
instr.opcode() == HloOpcode::kTranspose);
}
bool IsFusible(const HloInstruction& instr) {
return IsInputFusible(instr) || IsLoopFusible(instr);
}
bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
// We can fuse reduces and loop fusions. Elementwise instructions can be fused
// with any other instruction.
// Note that scatter cannot be the root of a multi-output fusion because
// its emitter doesn't support it.
return instr.IsFusible() &&
(IsInputFusibleReduction(instr) ||
instr.IsLoopFusion() || // TODO(b/130013493): Use IsLoopFusible here.
instr.IsElementwise());
}
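// A small sketch of how these predicates compose (hypothetical call site):
//
//   const HloInstruction& instr = ...;  // e.g. a fusible kTranspose
//   // kTranspose is in the loop-fusible list but is neither a reduction nor
//   // a scatter, so:
//   CHECK(IsLoopFusible(instr));
//   CHECK(!IsInputFusible(instr));
//   CHECK(IsFusible(instr));  // IsInputFusible || IsLoopFusible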
} // namespace gpu
} // namespace xla

View File

@ -69,6 +69,10 @@ bool IsInputFusibleScatter(const HloInstruction& instr);
bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
const HloInstruction& instr2);
// Whether `instr` is a candidate for sibling fusion or as a consumer in
// a producer-consumer multi-output fusion.
bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr);
} // namespace gpu
} // namespace xla

View File

@ -469,11 +469,12 @@ TEST_F(GpuFusibleTest,
TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p0 = f32[32,32,32]{2,1,0} parameter(0)
c0 = f32[] constant(0)
exp = f32[2,2,2]{2,1,0} exponential(p0)
reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp)
exp = f32[32,32,32]{2,1,0} exponential(p0)
reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
to_apply=scalar_add
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
})"))
.ValueOrDie();
const HloInstruction* reduce =
@ -573,24 +574,28 @@ TEST_F(GpuFusibleTest,
ShapesCompatibleForMultiOutputFusion_DifferentReduceDimensions) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_reduce_1 {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
c0 = f32[] constant(0)
ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} p0.1, f32[] c0), dimensions={0}, to_apply=scalar_add
ROOT reduce = f32[32,32]{1,0} reduce(f32[32,32,32]{2,1,0} p0.1, f32[] c0),
dimensions={0}, to_apply=scalar_add
}
fused_reduce_2 {
p0.2 = f32[2,2,2]{2,1,0} parameter(0)
mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2)
p0.2 = f32[32,32,32]{2,1,0} parameter(0)
mul = f32[32,32,32]{2,1,0} multiply(f32[32,32,32]{2,1,0} p0.2,
f32[32,32,32]{2,1,0} p0.2)
c1 = f32[] constant(0)
ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={2}, to_apply=scalar_add
ROOT reduce = f32[32,32]{1,0} reduce(f32[32,32,32]{2,1,0} mul, f32[] c1),
dimensions={2}, to_apply=scalar_add
}
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p1 = f32[2,2,2]{2,1,0} parameter(1)
reduce_1 = f32[2,2]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1
reduce_2 = f32[2,2]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce_1, reduce_2)
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[32,32,32]{2,1,0} parameter(1)
reduce_1 = f32[32,32]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1
reduce_2 = f32[32,32]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
tuple(reduce_1, reduce_2)
})"))
.ValueOrDie();
const HloInstruction* fusion_1 =
@ -604,28 +609,31 @@ TEST_F(GpuFusibleTest,
ShapesCompatibleForMultiOutputFusion_NoReductionToVector) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_element_wise {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
p1.1 = f32[32,32,32]{2,1,0} parameter(1)
ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
}
fused_reduce {
p0.2 = f32[2,2,2]{2,1,0} parameter(0)
mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2,
f32[2,2,2]{2,1,0} p0.2)
broadcast = f32[2,2,2,2]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
p0.2 = f32[32,32,32]{2,1,0} parameter(0)
mul = f32[32,32,32]{2,1,0} multiply(f32[32,32,32]{2,1,0} p0.2,
f32[32,32,32]{2,1,0} p0.2)
broadcast = f32[32,32,32,32]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
c1 = f32[] constant(0)
// Note that reduce is not a reduction-to-vector.
ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2,2]{3,2,1,0} broadcast,
ROOT reduce = f32[32,32]{1,0} reduce(f32[32,32,32,32]{3,2,1,0} broadcast,
f32[] c1), dimensions={1,3}, to_apply=scalar_add
}
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p1 = f32[2,2,2]{2,1,0} parameter(1)
element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[32,32,32]{2,1,0} parameter(1)
element_wise = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop,
calls=fused_element_wise
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(element_wise),
kind=kLoop, calls=fused_reduce
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
tuple(fusion, element_wise)
})"))
.ValueOrDie();
const HloInstruction* fusion_1 =

View File

@ -82,6 +82,45 @@ bool DotImplementedAsGemm(const HloInstruction& dot) {
}
return false;
}
// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// the dimensions more minor than the given dimensions, and middle is the size of
// the given dimensions.
std::tuple<int64, int64, int64> PartitionShapeByMiddleDimensions(
const Shape& shape, DimensionVector dims_middle) {
CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
absl::Span<const int64> minor_to_major = LayoutUtil::MinorToMajor(shape);
int64 values[3] = {1, 1, 1};
enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
Segment cur_segment = kMinor;
// Iterate through the dimensions for the three segments in the order of
// minor, middle and major to accumulate the size of each segment.
absl::c_for_each(minor_to_major, [&](int64 cur_dim) {
if (cur_segment != kMajor) {
// Handle change of segments.
bool cur_dim_in_middle = absl::c_any_of(
dims_middle, [&](int64 dim) { return dim == cur_dim; });
if (cur_segment == kMinor) {
if (cur_dim_in_middle) {
cur_segment = kMiddle;
}
} else if (cur_segment == kMiddle) {
if (!cur_dim_in_middle) {
cur_segment = kMajor;
}
}
}
values[cur_segment] *= shape.dimensions(cur_dim);
});
return std::make_tuple(values[kMajor], values[kMiddle], values[kMinor]);
}
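// Worked example for PartitionShapeByMiddleDimensions: for f32[2,3,4,5] with
// the default {3,2,1,0} layout and dims_middle = {1, 2}, the minor-to-major
// walk visits dims 3, 2, 1, 0 and accumulates minor = 5, middle = 4 * 3 = 12,
// major = 2, so the result is (2, 12, 5).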
} // namespace
bool ImplementedAsGemm(const HloInstruction& hlo) {
@ -174,10 +213,74 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
dims_to_keep.push_back(dim);
}
}
return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
dims_to_keep) ||
LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
dims_to_keep) &&
!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
reduce.dimensions())) {
return false;
}
bool is_row_reduction;
DimensionVector dims_in_elem;
std::tie(is_row_reduction, dims_in_elem) =
GetReductionKindAndContiguousComponents(input->shape(),
reduce.dimensions());
if (is_row_reduction) {
// For row reduction, the tile block is 1 x tile_size_x, and we are reducing
// along tile_size_x, which needs to be large enough to make the tiling
// implementation efficient.
return dims_in_elem[2] >= kWarpSize;
}
// For column reduction, the tile block is tile_size_y x tile_size_x, and we
// are reducing along tile_size_y. Both tile_size_x and tile_size_y need to be
// large enough to make the tiling implementation efficient.
return dims_in_elem[2] >= kWarpSize && dims_in_elem[1] >= kWarpSize;
}
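// Worked example: reducing f32[8,6,16] along dimension 2 yields a row
// reduction with dims_in_elem = {1, 48, 16}; because 16 < kWarpSize, this
// function returns false, which is what the RowReductionWithSmallDimensionNotTiled
// test added elsewhere in this change exercises.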
std::pair<bool, DimensionVector> GetReductionKindAndContiguousComponents(
const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
DimensionVector dims_to_keep;
for (int64 dim = 0; dim < input_shape.rank(); ++dim) {
if (!absl::c_linear_search(dims_to_reduce, dim)) {
dims_to_keep.push_back(dim);
}
}
if (dims_to_keep.empty()) {
return std::make_pair(
true, DimensionVector{1, 1, ShapeUtil::ElementsIn(input_shape)});
}
if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
dims_to_keep)) {
int64 num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1;
std::tie(num_reduced_major, num_kept, num_reduced_minor) =
PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
if (num_kept == 1) {
return std::make_pair(
true, DimensionVector{1, 1, num_reduced_minor * num_reduced_major});
}
if (num_reduced_minor == 1) {
return std::make_pair(false,
DimensionVector{1, num_reduced_major, num_kept});
}
return std::make_pair(
true, DimensionVector{num_reduced_major, num_kept, num_reduced_minor});
}
int64 num_kept_major = 1, num_reduced = 1, num_kept_minor = 1;
std::tie(num_kept_major, num_reduced, num_kept_minor) =
PartitionShapeByMiddleDimensions(
input_shape,
DimensionVector(dims_to_reduce.begin(), dims_to_reduce.end()));
if (num_kept_minor == 1) {
return std::make_pair(true,
DimensionVector{1, num_kept_major, num_reduced});
}
return std::make_pair(
false, DimensionVector{num_kept_major, num_reduced, num_kept_minor});
}
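// Worked examples for f32[8,64,32] with the default {2,1,0} layout:
//   dims_to_reduce = {2}    -> (true,  {1, 512, 32}): row reduction; 512 kept
//                              elements, each reducing a row of 32 elements.
//   dims_to_reduce = {0, 1} -> (false, {1, 512, 32}): column reduction; for
//                              each of the 32 kept columns, 512 elements are
//                              reduced.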
// This emits a device-side call to

View File

@ -152,6 +152,26 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo);
// kept are contiguous in the input of the reduce instruction.
bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce);
// Given the input shape and dimensions to reduce for a reduction, returns
// <is_row_reduction, DimensionVector>:
// is_row_reduction: indicates whether the reduction is a row reduction or a
// column reduction.
// DimensionVector: contains the size of the three contiguous components for the
// reduction [depth, height, width]. For row reduction, height is the size of
// the dimensions to keep, depth is the size of the dimensions to reduce that
// are more major than the dimensions to keep, and width is the size of the
// dimensions to reduce that are more minor than the dimensions to keep. For
// column reduction, height is the size of the dimensions to reduce, depth is
// the size of the dimensions to keep that are more major than the dimensions
// to reduce, and width is the size of the dimensions to keep that are more
// minor than the dimensions to reduce.
//
// Prerequisite: the reduction instruction passes the check
// IsReductionFromOrToContiguousDimensions, which guarantees either the
// dimensions to reduce or the dimensions to keep are consecutive.
std::pair<bool, DimensionVector> GetReductionKindAndContiguousComponents(
const Shape& input_shape, absl::Span<const int64> dims_to_reduce);
// Emits call to "vprintf" with given format and arguments.
llvm::Value* EmitPrintf(absl::string_view fmt,
absl::Span<llvm::Value* const> arguments,

View File

@ -3546,105 +3546,6 @@ Status AreFusedReductionOutputsConsistent(
return Status::OK();
}
// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// the dimensions more minor than the given dimensions, and middle is the size of
// the given dimensions.
std::tuple<int64, int64, int64> PartitionShapeByMiddleDimensions(
const Shape& shape, DimensionVector dims_middle) {
CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
absl::Span<const int64> minor_to_major = LayoutUtil::MinorToMajor(shape);
int64 values[3] = {1, 1, 1};
enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
Segment cur_segment = kMinor;
// Iterate through the dimensions for the three segments in the order of
// minor, middle and major to accumulate the size of each segment.
absl::c_for_each(minor_to_major, [&](int64 cur_dim) {
if (cur_segment != kMajor) {
// Handle change of segments.
bool cur_dim_in_middle = absl::c_any_of(
dims_middle, [&](int64 dim) { return dim == cur_dim; });
if (cur_segment == kMinor) {
if (cur_dim_in_middle) {
cur_segment = kMiddle;
}
} else if (cur_segment == kMiddle) {
if (!cur_dim_in_middle) {
cur_segment = kMajor;
}
}
}
values[cur_segment] *= shape.dimensions(cur_dim);
});
return std::make_tuple(values[kMajor], values[kMiddle], values[kMinor]);
}
// Given the input shape and dimensions to reduce for a reduction, returns
// <is_row_reduction, DimensionVector>:
// is_row_reduction: indicates whether the reduction is a row reduction or a
// column reduction.
// DimensionVector: contains the size of the three contiguous components for the
// reduction [depth, height, width]. For row reduction, height is the size of
// the dimensions to keep, depth is the size of the dimensions to reduce that
// are more major than the dimensions to keep, and width is the size of the
// dimensions to reduce that are more minor than the dimensions to keep. For
// column reduction, height is the size of the dimensions to reduce, depth is
// the size of the dimensions to keep that are more major than the dimensions
// to reduce, and width is the size of the dimensions to keep that are more
// minor than the dimensions to reduce.
//
// Prerequisite: the reduction instruction passes the check
// IsReductionFromOrToContiguousDimensions, which guarantees either the
// dimensions to reduce or the dimensions to keep are consecutive.
std::pair<bool, DimensionVector> GetReductionKindAndContiguousComponents(
const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
DimensionVector dims_to_keep;
for (int64 dim = 0; dim < input_shape.rank(); ++dim) {
if (!absl::c_linear_search(dims_to_reduce, dim)) {
dims_to_keep.push_back(dim);
}
}
if (dims_to_keep.empty()) {
return std::make_pair(
true, DimensionVector{1, 1, ShapeUtil::ElementsIn(input_shape)});
}
if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
dims_to_keep)) {
int64 num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1;
std::tie(num_reduced_major, num_kept, num_reduced_minor) =
PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
if (num_kept == 1) {
return std::make_pair(
true, DimensionVector{1, 1, num_reduced_minor * num_reduced_major});
}
if (num_reduced_minor == 1) {
return std::make_pair(false,
DimensionVector{1, num_reduced_major, num_kept});
}
return std::make_pair(
true, DimensionVector{num_reduced_major, num_kept, num_reduced_minor});
}
int64 num_kept_major = 1, num_reduced = 1, num_kept_minor = 1;
std::tie(num_kept_major, num_reduced, num_kept_minor) =
PartitionShapeByMiddleDimensions(
input_shape,
DimensionVector(dims_to_reduce.begin(), dims_to_reduce.end()));
if (num_kept_minor == 1) {
return std::make_pair(true,
DimensionVector{1, num_kept_major, num_reduced});
}
return std::make_pair(
false, DimensionVector{num_kept_major, num_reduced, num_kept_minor});
}
// Returns true if all the transitive users of hlo before hitting users in
// use_chain_endings are elementwise operations.
bool AreUsersElementwise(const HloInstruction* hlo,

View File

@ -219,7 +219,6 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
}
builder.DisableUnitAtATime = false;
builder.DisableUnrollLoops = opt_level == 0;
builder.LoopVectorize = opt_level > 0;
builder.SLPVectorize = opt_level > 1 && size_level < 2;

View File

@ -45,11 +45,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
}
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
// We can fuse reduces and loop fusions. Elementwise instructions can be fused
// with any other instruction.
return instr->IsFusible() &&
(IsInputFusibleReduction(*instr) || instr->IsLoopFusion() ||
instr->IsElementwise());
return IsFusibleAsMultiOutputFusionRoot(*instr);
}
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,

View File

@ -400,11 +400,12 @@ TEST_F(MultiOutputFusionTest,
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p0 = f32[32,32,32]{2,1,0} parameter(0)
c0 = f32[] constant(0)
exp = f32[2,2,2]{2,1,0} exponential(p0)
reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp)
exp = f32[32,32,32]{2,1,0} exponential(p0)
reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
to_apply=scalar_add_computation
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
})"))
.ValueOrDie();
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
@ -420,18 +421,19 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_add {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
p1.1 = f32[32,32,32]{2,1,0} parameter(1)
ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
}
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p1 = f32[2,2,2]{2,1,0} parameter(1)
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[32,32,32]{2,1,0} parameter(1)
c0 = f32[] constant(0)
add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add)
add = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
reduce = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
to_apply=scalar_add_computation
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, add)
})"))
.ValueOrDie();
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
@ -447,31 +449,37 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
p1.1 = f32[32,32,32]{2,1,0} parameter(1)
c0 = f32[] constant(0)
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
broadcast = f32[32,32,32]{2,1,0} broadcast(f32[] c0), dimensions={}
greater-than = pred[32,32,32]{2,1,0} compare(f32[32,32,32]{2,1,0} p1.1,
f32[32,32,32]{2,1,0} broadcast), direction=GT
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
ROOT select = f32[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
greater-than, f32[32,32,32]{2,1,0} p0.1, f32[32,32,32]{2,1,0} broadcast)
}
fused_reduce {
p0.2 = f32[2,2,2]{2,1,0} parameter(0)
p0.2 = f32[32,32,32]{2,1,0} parameter(0)
c1 = f32[] constant(0)
r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation
mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2)
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
r1 = f32[32,32]{1,0} reduce(p0.2, c1), dimensions={2},
to_apply=scalar_add_computation
mul = f32[32,32,32]{2,1,0} multiply(p0.2, p0.2)
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
to_apply=scalar_add_computation
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
}
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p1 = f32[2,2,2]{2,1,0} parameter(1)
select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select)
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[32,32,32]{2,1,0} parameter(1)
select = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
calls=fused_reduce
gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
tuple(gte1, gte1, select)
})"))
.ValueOrDie();
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
@ -518,30 +526,36 @@ TEST_F(MultiOutputFusionTest,
ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f16[2,2,2]{2,1,0} parameter(1)
p1.1 = f16[32,32,32]{2,1,0} parameter(1)
c0 = f16[] constant(0)
broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={}
greater-than = pred[2,2,2]{2,1,0} compare(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast), direction=GT
p0.1 = f16[2,2,2]{2,1,0} parameter(0)
ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast)
broadcast = f16[32,32,32]{2,1,0} broadcast(f16[] c0), dimensions={}
greater-than = pred[32,32,32]{2,1,0} compare(f16[32,32,32]{2,1,0} p1.1,
f16[32,32,32]{2,1,0} broadcast), direction=GT
p0.1 = f16[32,32,32]{2,1,0} parameter(0)
ROOT select = f16[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
greater-than, f16[32,32,32]{2,1,0} p0.1, f16[32,32,32]{2,1,0} broadcast)
}
fused_reduce {
p0.2 = f16[2,2,2]{2,1,0} parameter(0)
convert = f32[2,2,2]{2,1,0} convert(p0.2)
p0.2 = f16[32,32,32]{2,1,0} parameter(0)
convert = f32[32,32,32]{2,1,0} convert(p0.2)
c1 = f32[] constant(0)
r1 = f32[2,2]{1,0} reduce(convert, c1), dimensions={2}, to_apply=scalar_add_computation
mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
r1 = f32[32,32]{1,0} reduce(convert, c1), dimensions={2},
to_apply=scalar_add_computation
mul = f32[32,32,32]{2,1,0} multiply(convert, convert)
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
to_apply=scalar_add_computation
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
}
ENTRY reduce {
p0 = f16[2,2,2]{2,1,0} parameter(0)
p1 = f16[2,2,2]{2,1,0} parameter(1)
select = f16[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) tuple(gte1, gte1, select)
p0 = f16[32,32,32]{2,1,0} parameter(0)
p1 = f16[32,32,32]{2,1,0} parameter(1)
select = f16[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
calls=fused_reduce
gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f16[32,32,32]{2,1,0})
tuple(gte1, gte1, select)
})"))
.ValueOrDie();
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());

View File

@ -162,5 +162,23 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
return std::make_tuple(input_layout, filter_layout, output_layout);
}
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
// se::Platform*s are global singletons guaranteed to live forever.
static auto* mutexes =
new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
tensorflow::mutex>();
tensorflow::mutex_lock global_lock(mu);
auto it = mutexes
->emplace(std::piecewise_construct,
std::make_tuple(stream_exec->platform(),
stream_exec->device_ordinal()),
std::make_tuple())
.first;
return tensorflow::mutex_lock{it->second};
}
} // namespace gpu
} // namespace xla

View File

@ -45,6 +45,14 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
const Layout& input, const Layout& filter,
const Layout& output);
// Generates and returns a unique lock for each provided executor.
// Guarantees that two blocks of code that each hold the lock for the same
// executor (as returned by this function) will not run concurrently.
//
// This is used to prevent other XLA instances from trying to autotune on a
// device while another thread is using it.
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec);
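// A minimal usage sketch (hypothetical autotuning call site):
//
//   tensorflow::mutex_lock lock = LockGpu(stream_exec);
//   // ... run algorithm autotuning on stream_exec's device ...
//   // The lock is released when `lock` goes out of scope.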
} // namespace gpu
} // namespace xla

View File

@ -485,9 +485,9 @@ TEST_F(GpuKernelTilingTest,
}
ENTRY kernel_entry {
arg0 = f32[8,64,4]{2,1,0} parameter(0)
arg0 = f32[8,64,32]{2,1,0} parameter(0)
constant0 = f32[] constant(0)
ROOT reduce0 = f32[8,4]{0,1} reduce(arg0, constant0), dimensions={1},
ROOT reduce0 = f32[8,32]{0,1} reduce(arg0, constant0), dimensions={1},
to_apply=reduction0
})";
@ -506,6 +506,37 @@ TEST_F(GpuKernelTilingTest,
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}
TEST_F(GpuKernelTilingTest, RowReductionWithSmallDimensionNotTiled) {
const char *const kHloString = R"(
HloModule reduction
reduction0 {
x0 = f32[] parameter(0)
y0 = f32[] parameter(1)
ROOT add0 = f32[] add(x0, y0)
}
ENTRY kernel_entry {
arg0 = f32[8,6,16]{2,1,0} parameter(0)
constant0 = f32[] constant(0)
ROOT reduce0 = f32[8,6]{1,0} reduce(arg0, constant0), dimensions={2},
to_apply=reduction0
})";
// Check that the kernel is not tiled by looking for llvm.nvvm.shfl.sync.down.
auto hlo_module =
ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @reduce
; CHECK-NOT: call float @llvm.nvvm.shfl.sync.down.f32
; CHECK: }
)",
/*match_optimized_ir=*/true);
// Check that the kernel runs correctly.
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -111,8 +111,8 @@ TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) {
hlo_module->AddEmbeddedComputation(embedded_builder.Build());
}
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto reduce_shape = ShapeUtil::MakeShape(F32, {2});
auto param_shape = ShapeUtil::MakeShape(F32, {32, 32});
auto reduce_shape = ShapeUtil::MakeShape(F32, {32});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));
HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(

View File

@ -85,10 +85,6 @@ class Thunk {
return Status::OK();
}
// Returns true if this kernel will autotune for the stream device the next
// time it is run.
virtual bool WillAutotuneKernel(se::Stream* /*stream*/) { return false; }
// Execute the kernel for the thunk on the given stream. This method must be
// called after Initialize and can be called multiple times over Thunk's
// lifetime. 'stream' and 'profiler' must be non-null.

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@ -384,27 +385,11 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
StatusOr<HloInstruction*> ElideDegenerateDims(
HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
CHECK(absl::c_is_sorted(dims_to_elide));
const Shape& input_shape = operand->shape();
// First accumulate in reverse
std::vector<int64> new_shape_dim_bounds;
new_shape_dim_bounds.reserve(input_shape.dimensions_size() -
dims_to_elide.size());
int64 dims_to_elide_idx = dims_to_elide.size() - 1;
for (int64 i = input_shape.dimensions_size() - 1; i >= 0; i--) {
if (dims_to_elide_idx >= 0 && i == dims_to_elide[dims_to_elide_idx]) {
CHECK_EQ(input_shape.dimensions(i), 1);
dims_to_elide_idx--;
} else {
new_shape_dim_bounds.push_back(input_shape.dimensions(i));
}
}
absl::c_reverse(new_shape_dim_bounds);
Shape output_shape =
ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds);
return MakeReshapeHlo(output_shape, operand);
return MakeReshapeHlo(
ShapeUtil::FilterDimensions(
[&](int64 dim) { return !absl::c_linear_search(dims_to_elide, dim); },
operand->shape()),
operand);
}
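// Example for ElideDegenerateDims above: with dims_to_elide = {0, 2} and an
// f32[1,5,1,7] operand, the filtered shape is f32[5,7], so the result is a
// reshape of the operand to f32[5,7].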
StatusOr<HloInstruction*> InsertDegenerateDims(

View File

@ -957,7 +957,6 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
//
// Returns true if:
//
// * fusion is a loop or input fusion, AND
// * fusion_param is used by the root of dynamic-update-slice as the "base" of
// the update, i.e. the thing being updated, AND
// * all other uses of fusion_param are dynamic-slices that slice the same
@ -977,13 +976,6 @@ static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
// fusion must be a loop or input fusion.
auto kind = fusion->fusion_kind();
if (kind != HloInstruction::FusionKind::kLoop &&
kind != HloInstruction::FusionKind::kInput) {
return false;
}
// fusion_param must be used by the root as the "base" of the
// dynamic-update-slice. The natural way to check this would be
//

View File

@ -423,10 +423,12 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
#ifndef NDEBUG
const Literal* input_literal = arg_literals_[parameter->parameter_number()];
VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()))
<< "parameter shape is: " << ShapeUtil::HumanString(parameter->shape())
DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
input_literal->shape()))
<< "parameter shape is: "
<< ShapeUtil::HumanStringWithLayout(parameter->shape())
<< ", but input literal shape is: "
<< ShapeUtil::HumanString(input_literal->shape());
<< ShapeUtil::HumanStringWithLayout(input_literal->shape());
#endif
return Status::OK();

View File

@ -2751,7 +2751,12 @@ template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
const InternalCompareFunction* operand_order,
bool ignore_control_predecessors) {
visitor->ReserveVisitStates(root->GetModule()->instruction_count());
// Calculating the instruction count within a module can be expensive on large
// models, so only do it if the visit state is empty. This helps when the
// same visitor is reused across many computations of a single module.
if (visitor->VisitStateSize() == 0) {
visitor->ReserveVisitStates(root->GetModule()->instruction_count());
}
// dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
//

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@ -122,8 +123,9 @@ Status HloModuleGroupMetadata::Build() {
// Visit the computations in postorder so that the companion information grows
// from inner computations to outer ones.
for (HloModule* module : modules_) {
FunctionVisitor function_visitor(visitor);
for (HloComputation* computation : module->MakeComputationPostOrder()) {
TF_RETURN_IF_ERROR(computation->Accept(visitor));
TF_RETURN_IF_ERROR(computation->Accept(&function_visitor));
}
}
TF_RETURN_IF_ERROR(VerifyCompanionSets());
@ -370,8 +372,9 @@ Status HloModuleGroupMetadata::RecordInstructions() {
};
for (HloModule* module : modules_) {
FunctionVisitor function_visitor(visitor);
for (auto* computation : module->computations()) {
TF_RETURN_IF_ERROR(computation->Accept(visitor));
TF_RETURN_IF_ERROR(computation->Accept(&function_visitor));
}
}
VLOG(2) << "Created " << channels_.size() << " channels";

View File

@ -2671,7 +2671,7 @@ bool HloParser::ParseAttributeHelper(
if (!ParseAttributeName(&name)) {
return Error(loc, "error parsing attributes");
}
VLOG(1) << "Parsing attribute " << name;
VLOG(3) << "Parsing attribute " << name;
if (!seen_attrs->insert(name).second) {
return Error(loc, StrFormat("attribute %s already exists", name));
}
@ -2943,7 +2943,7 @@ bool HloParser::ParseAttributeAsProtoMessageHelper(
if (!ParseAttributeName(&name)) {
return Error(loc, "error parsing attributes");
}
VLOG(1) << "Parsing attribute " << name;
VLOG(3) << "Parsing attribute " << name;
if (!seen_attrs->insert(name).second) {
return Error(loc, StrFormat("attribute %s already exists", name));
}
@ -3650,7 +3650,7 @@ bool HloParser::CanBeShape() {
}
bool HloParser::ParseName(string* result) {
VLOG(1) << "ParseName";
VLOG(3) << "ParseName";
if (lexer_.GetKind() != TokKind::kIdent &&
lexer_.GetKind() != TokKind::kName) {
return TokenError("expects name");
@ -3670,7 +3670,7 @@ bool HloParser::ParseAttributeName(string* result) {
}
bool HloParser::ParseString(string* result) {
VLOG(1) << "ParseString";
VLOG(3) << "ParseString";
if (lexer_.GetKind() != TokKind::kString) {
return TokenError("expects string");
}
@ -3784,7 +3784,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) {
}
bool HloParser::ParseOpcode(HloOpcode* result) {
VLOG(1) << "ParseOpcode";
VLOG(3) << "ParseOpcode";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects opcode");
}
@ -3800,7 +3800,7 @@ bool HloParser::ParseOpcode(HloOpcode* result) {
}
bool HloParser::ParseFftType(FftType* result) {
VLOG(1) << "ParseFftType";
VLOG(3) << "ParseFftType";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects fft type");
}
@ -3829,7 +3829,7 @@ bool HloParser::ParseComparisonDirection(ComparisonDirection* result) {
}
bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
VLOG(1) << "ParseFusionKind";
VLOG(3) << "ParseFusionKind";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects fusion kind");
}
@ -3846,7 +3846,7 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
}
bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
VLOG(1) << "ParseRandomDistribution";
VLOG(3) << "ParseRandomDistribution";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects random distribution");
}
@ -3863,7 +3863,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
}
bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
VLOG(1) << "ParsePrecision";
VLOG(3) << "ParsePrecision";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects random distribution");
}
@ -3880,7 +3880,7 @@ bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
}
bool HloParser::ParseInt64(int64* result) {
VLOG(1) << "ParseInt64";
VLOG(3) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {
return TokenError("expects integer");
}
@ -3969,7 +3969,7 @@ bool HloParser::ParseBool(bool* result) {
}
bool HloParser::ParseToken(TokKind kind, const string& msg) {
VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg;
VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
if (lexer_.GetKind() != kind) {
return TokenError(msg);
}

View File

@ -41,6 +41,8 @@ class HloPassInterface {
// module group. Ideally, the module group variant would be named "Run" as
// well, but C++ does not handle overloaded virtual methods well.
virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
virtual bool IsPassPipeline() { return false; }
};
// Base class for passes which are module-scoped.

View File

@ -58,14 +58,21 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal(
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
bool changed = false;
for (HloPassInterface* pass : passes) {
VLOG(1) << " HLO pass " << pass->name();
absl::string_view pass_name = pass->name();
VLOG(1) << " HLO pass " << pass_name;
MaybeDumpHlo(*hlo,
/*after_pass_name=*/last_pass_name,
/*before_pass_name=*/pass->name());
/*before_pass_name=*/pass_name);
if (!pass->IsPassPipeline()) {
compilation_stats_->StartPass(pass_name);
}
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
changed |= pass_changed;
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
last_pass_name = string(pass->name());
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
last_pass_name = string(pass_name);
if (!pass->IsPassPipeline()) {
compilation_stats_->EndPass(pass_name);
}
}
MaybeDumpHlo(*hlo,
/*after_pass_name=*/last_pass_name,

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/compilation_stats.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -34,7 +35,14 @@ namespace xla {
// Pipeline of HLO passes.
class HloPassPipeline : public HloPassInterface {
public:
explicit HloPassPipeline(const string& name) : name_(name) {}
explicit HloPassPipeline(const string& name,
CompilationStats* compilation_stats = nullptr)
: name_(name), compilation_stats_(compilation_stats) {
if (compilation_stats == nullptr) {
empty_compilation_stats_ = CompilationStats::MakeNoopStats();
compilation_stats_ = empty_compilation_stats_.get();
}
}
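  // A minimal construction sketch (hypothetical pipeline; MakeNoopStats() is
  // the only CompilationStats factory shown in this change):
  //
  //   auto stats = CompilationStats::MakeNoopStats();
  //   HloPassPipeline pipeline("example-pipeline", stats.get());
  //   // pipeline.AddPass<SomePass>(...);   // SomePass is a placeholder
  //   // StatusOr<bool> changed = pipeline.Run(module);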
absl::string_view name() const override { return name_; }
// Add a pass to the pipeline. It should be called with the arguments for the
@ -65,6 +73,8 @@ class HloPassPipeline : public HloPassInterface {
StatusOr<bool> Run(HloModule* module) override;
StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
bool IsPassPipeline() override { return true; }
private:
// Returns the set of passes which are enabled. DebugOptions can selectively
// disable passes via --xla_disable_hlo_passes flag.
@ -105,6 +115,11 @@ class HloPassPipeline : public HloPassInterface {
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
bool run_called_ = false;
CompilationStats* compilation_stats_;
// Default stats instance for when one is not passed in the constructor.
// Use via compilation_stats_, not directly.
std::unique_ptr<CompilationStats> empty_compilation_stats_;
};
} // namespace xla

View File

@ -35,7 +35,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
const ShapeIndex& index) {
BufferAllocation::Slice buffer_slice;
if (hlo.opcode() == HloOpcode::kParameter &&
hlo.parent() == hlo.parent()->parent()->entry_computation()) {
hlo.parent() == module_.entry_computation()) {
// Entry computation parameters may alias with each other but may not alias
// with our temporary buffers.
buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0);
@ -78,12 +78,9 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
.xla_llvm_enable_invariant_load_metadata()) {
// Parameters of the entry computation are never stored to, loading from a
// parameter pointer should always return the same result within a loop.
if (hlo.opcode() == HloOpcode::kParameter) {
const std::vector<HloInstruction*>& parameter_instructions =
module_.entry_computation()->parameter_instructions();
if (absl::c_linear_search(parameter_instructions, &hlo)) {
array->MarkInvariantOverWholeProgram(context_);
}
if (hlo.opcode() == HloOpcode::kParameter &&
hlo.parent() == module_.entry_computation()) {
array->MarkInvariantOverWholeProgram(context_);
}
}
}

View File

@ -2544,7 +2544,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
};
// Check the shapes of computation parameters and return types.
if (!ShapeUtil::Equal(condition.result(), ShapeUtil::MakeShape(PRED, {}))) {
if (!ShapeUtil::Compatible(condition.result(),
ShapeUtil::MakeShape(PRED, {}))) {
return InvalidArgument("Condition must return a boolean; got %s.",
shape_string());
}
@ -2564,8 +2565,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& branch_index,
absl::Span<const ProgramShape> branch_computations,
absl::Span<const Shape> branch_operands) {
if (!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(S32, {}))) {
if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
return InvalidArgument("branch_index must be bool or int32; got %s.",
ShapeUtil::HumanString(branch_index));
}

View File

@ -285,7 +285,7 @@ Status TransferManager::TransferBufferFromDevice(
void* destination) {
if (source.size() < size) {
return FailedPrecondition(
"Source allocation on device not large enough for data tranfer: "
"Source allocation on device not large enough for data transfer: "
"%d < %d",
source.size(), size);
}
@ -298,7 +298,7 @@ Status TransferManager::TransferBufferToDevice(
se::DeviceMemoryBase* destination) {
if (destination->size() < size) {
return FailedPrecondition(
"Destination allocation on device not large enough for data tranfer: "
"Destination allocation on device not large enough for data transfer: "
"%d < %d",
destination->size(), size);
}
@ -336,4 +336,9 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
return std::move(shaped_buffer);
}
StatusOr<Shape> TransferManager::ChooseCompactLayoutForShape(
const Shape& host_shape) const {
return LayoutUtil::GetWithDefaultLayout(host_shape);
}
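// For example, with this default implementation an f32[2,3] host shape with a
// {0,1} layout comes back as f32[2,3]{1,0}: the dimensions are kept and the
// layout is reset to the default major-to-minor order.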
} // namespace xla

View File

@ -216,6 +216,15 @@ class TransferManager {
// region for a host-to-device transfer.
virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;
// Chooses a compact layout for 'shape', ignoring any existing layout on
// 'shape'. What "reasonable" means is left up to the backend. The
// intended use case is to choose a layout that avoids excessive padding on
// devices that have tiled memory architectures.
// The default implementation always picks a default (major-to-minor) layout.
// Fails if 'shape' cannot be represented by the device.
virtual StatusOr<Shape> ChooseCompactLayoutForShape(
const Shape& host_shape) const;
// Allocates a ScopedShapedBuffer which can hold data with the given on-host
// shape. The on-device shape may be different as indicated by
// HostShapeToDeviceShape.

View File

@ -200,6 +200,12 @@ class Shape {
bool operator==(const Shape& other) const { return Equal()(*this, other); }
bool operator!=(const Shape& other) const { return !(*this == other); }
template <typename H>
friend H AbslHashValue(H h, const Shape& s) {
return H::combine(std::move(h), s.element_type_, s.dimensions_,
s.dynamic_dimensions_, s.tuple_shapes_, s.layout_);
}
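  // A minimal usage sketch enabled by this hash support (hypothetical):
  //
  //   absl::flat_hash_set<Shape> shapes;
  //   shapes.insert(ShapeUtil::MakeShape(F32, {32, 32}));
  //   CHECK(shapes.contains(ShapeUtil::MakeShape(F32, {32, 32})));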
private:
// The element type of this shape (tuple, array, etc).
PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;

View File

@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape.h"
#include <numeric>
#include "absl/hash/hash_testing.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
@ -210,5 +212,11 @@ TEST_F(ShapeTest, ProgramShapeToString) {
prog.ToString());
}
TEST_F(ShapeTest, SupportsAbslHash) {
EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(
{opaque_, token_, scalar_, scalar_with_tile_, matrix_, matrix2_, tuple_,
nested_tuple_, dyanmic_matrix_}));
}
} // namespace
} // namespace xla

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <math.h>
#include <algorithm>
#include <memory>
#include <new>
@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test_benchmark.h"
@ -604,10 +604,9 @@ std::unique_ptr<HloComputation> MakeReduceTestComputation() {
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
auto hlo_module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
auto const0 = builder.AddInstruction(
HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {32}), 0));
auto const1 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
@ -618,7 +617,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
HloInstruction::FusionKind::kInput);
EXPECT_TRUE(
LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(15),
LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(496),
ExecuteAndTransfer(std::move(hlo_module), {})));
}
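// Sanity check on the updated expected value: iota(S32, {32}) produces
// 0, 1, ..., 31 and sum_{i=0}^{31} i = 31 * 32 / 2 = 496.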
@ -896,8 +895,7 @@ void BM_ParallelFusion(int num_iters) {
// Initialize thread pool.
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
intra_op_parallelism_threads);
tensorflow::EigenThreadPoolWrapper tp(&pool);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
// Initialize ExecutableRunOptions.
ExecutableRunOptions options;

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
@ -108,12 +107,10 @@ struct LocalClientTestBase::EigenThreadPoolWrapper {
explicit EigenThreadPoolWrapper()
: pool(new tensorflow::thread::ThreadPool(
tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
device(new Eigen::ThreadPoolDevice(wrapper.get(),
wrapper->NumThreads())) {}
device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
pool->NumThreads())) {}
std::unique_ptr<tensorflow::thread::ThreadPool> pool;
std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
std::unique_ptr<Eigen::ThreadPoolDevice> device;
};

View File

@ -295,255 +295,191 @@ XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) {
const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p0 = f32[32,32,32]{2,1,0} parameter(0)
c0 = f32[] constant(0)
r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
r1 = f32[32,32]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
mul = f32[32,32,32]{2,1,0} multiply(p0, p0)
c1 = f32[] constant(5)
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
}
ENTRY reduce {
p = f32[2,2,2]{2,1,0} parameter(0)
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
calls=fused_reduce
p = f32[32,32,32]{2,1,0} parameter(0)
ROOT fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(p), kind=kInput,
calls=fused_reduce
})");
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
result));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) {
const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p0 = f32[32,32,32]{2,1,0} parameter(0)
c0 = f32[] constant(0)
r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
r1 = f32[32,32]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
mul = f32[32,32,32]{2,1,0} multiply(p0, p0)
c1 = f32[] constant(5)
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
}
ENTRY reduce {
p = f32[2,2,2]{2,1,0} parameter(0)
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
calls=fused_reduce
p = f32[32,32,32]{2,1,0} parameter(0)
ROOT fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(p), kind=kInput,
calls=fused_reduce
})");
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
result));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) {
const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p0 = f32[2,32,32]{2,1,0} parameter(0)
c0 = f32[] constant(0)
r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
r1 = f32[32]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
mul = f32[2,32,32]{2,1,0} multiply(p0, p0)
c1 = f32[] constant(1.17549e-38)
r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add
ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3)
r2 = f32[32]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
r3 = f32[32]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add
ROOT tuple = (f32[32]{0}, f32[32]{0}, f32[32]{0}) tuple(r1, r2, r3)
}
ENTRY reduce {
p = f32[2,2,2]{2,1,0} parameter(0)
ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
calls=fused_reduce
p = f32[2,32,32]{2,1,0} parameter(0)
ROOT fusion = (f32[32]{0}, f32[32]{0}, f32[32]{0}) fusion(p), kind=kInput,
calls=fused_reduce
})");
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
LiteralUtil::CreateR1<float>({36, 64}),
LiteralUtil::CreateR1<float>({66, 138})),
result));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) {
const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p0 = f32[2,32,32]{2,1,0} parameter(0)
c0 = f32[] constant(0)
r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
r1 = f32[2,32]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
mul = f32[2,32,32]{2,1,0} multiply(p0, p0)
c1 = f32[] constant(5)
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0})
r2 = f32[2,32]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
ROOT tuple = (f32[2,32,32]{2,1,0}, f32[2,32]{1,0}, f32[2,32]{1,0})
tuple(p0, r1, r2)
}
ENTRY reduce {
p = f32[2,2,2]{2,1,0} parameter(0)
ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
kind=kInput, calls=fused_reduce
p = f32[2,32,32]{2,1,0} parameter(0)
ROOT fusion = (f32[2,32,32]{2,1,0}, f32[2,32]{1,0}, f32[2,32]{1,0})
fusion(p), kind=kInput, calls=fused_reduce
})");
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
result));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) {
const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p0 = f32[32,32,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
r1 = f32[32,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
mul = f32[32,32,2]{2,1,0} multiply(p0, p0)
c1 = f32[] constant(5)
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0})
r2 = f32[32,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
ROOT tuple = (f32[32,2]{1,0}, f32[32,32,2]{2,1,0}, f32[32,2]{1,0})
tuple(r1, mul, r2)
}
ENTRY reduce {
p = f32[2,2,2]{2,1,0} parameter(0)
ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
kind=kInput, calls=fused_reduce
p = f32[32,32,2]{2,1,0} parameter(0)
ROOT fusion = (f32[32,2]{1,0}, f32[32,32,2]{2,1,0}, f32[32,2]{1,0})
fusion(p), kind=kInput, calls=fused_reduce
})");
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
result));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) {
const string testcase = absl::StrCat(kScalarOps, R"(
const string testcase = R"(
HloModule m, is_scheduled=true
Add {
lhsadd = f32[] parameter(0)
rhsadd = f32[] parameter(1)
ROOT add = f32[] add(lhsadd, rhsadd)
}
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p0 = f32[2,32,32]{2,1,0} parameter(0)
c0 = f32[] constant(0)
r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
r1 = f32[32]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
mul = f32[2,32,32]{2,1,0} multiply(p0, p0)
c1 = f32[] constant(5)
b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={}
mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1)
ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0})
tuple(r1, mul, mul2)
b1 = f32[2,32,32]{2,1,0} broadcast(c1), dimensions={}
mul2 = f32[2,32,32]{2,1,0} multiply(p0, b1)
ROOT tuple = (f32[32]{0}, f32[2,32,32]{2,1,0}, f32[2,32,32]{2,1,0})
tuple(r1, mul, mul2)
}
ENTRY reduce {
p = f32[2,2,2]{2,1,0} parameter(0)
ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
kind=kInput, calls=fused_reduce
})");
p = f32[2,32,32]{2,1,0} parameter(0)
ROOT fusion = (f32[32]{0}, f32[2,32,32]{2,1,0}, f32[2,32,32]{2,1,0})
fusion(p), kind=kInput, calls=fused_reduce
})";
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR1<float>({14, 22}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR3<float>(
{{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
result));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
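Likewise for this scalar variant (again not part of the diff): `r1` reduces over dimensions {0,2}, `mul` squares the input, and `mul2` multiplies by the broadcast constant 5, which yields the removed expected literals for the old 2x2x2 input:

```python
# Illustration only: reproduces the removed expected literals for the old
# 2x2x2 input of MultiOutputReduceFusionScalarWithExtraOutput.
import numpy as np

p = np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]])

r1 = p.sum(axis=(0, 2))   # Add-reduce over dimensions={0,2} -> [14. 22.]
mul = p * p               # element-wise square
mul2 = p * 5.0            # multiply by the broadcast constant 5

print(r1)
print(mul2)
```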
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) {
const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p0 = f32[2,32,32]{2,1,0} parameter(0)
init1 = f32[] parameter(1)
init2 = f32[] parameter(2)
r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add
r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
r1 = f32[2,32]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add
r2 = f32[2,32]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max
ROOT tuple = (f32[2,32]{1,0}, f32[2,32]{1,0}) tuple(r1, r2)
}
ENTRY reduce {
p = f32[2,2,2]{2,1,0} parameter(0)
p = f32[2,32,32]{2,1,0} parameter(0)
i = f32[] parameter(1)
j = f32[] parameter(2)
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
calls=fused_reduce
ROOT fusion = (f32[2,32]{1,0}, f32[2,32]{1,0}) fusion(p, i, j),
kind=kInput, calls=fused_reduce
})");
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
auto init1 = LiteralUtil::CreateR0<float>(5);
auto init2 = LiteralUtil::CreateR0<float>(6);
Literal result =
ExecuteNoHloPasses(std::move(module), {&param, &init1, &init2});
EXPECT_TRUE(LiteralTestUtil::Equal(
LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
result));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) {
const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) {
p0 = f16[2,2,2]{2,1,0} parameter(0)
convert = f32[2,2,2]{2,1,0} convert(p0)
fused_reduce (p0: f16[2,32,32]) -> (f32[2,32], f32[2,32], f16[2,32,32]) {
p0 = f16[2,32,32]{2,1,0} parameter(0)
convert = f32[2,32,32]{2,1,0} convert(p0)
c0 = f32[] constant(0)
r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add
mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
r1 = f32[2,32]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add
mul = f32[2,32,32]{2,1,0} multiply(convert, convert)
c1 = f32[] constant(5)
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0})
r2 = f32[2,32]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
ROOT tuple = (f32[2,32]{1,0}, f32[2,32]{1,0}, f16[2,32,32]{2,1,0})
tuple(r1, r2, p0)
}
ENTRY reduce {
p = f16[2,2,2]{2,1,0} parameter(0)
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
p = f16[2,32,32]{2,1,0} parameter(0)
ROOT fusion = (f32[2,32]{1,0}, f32[2,32]{1,0}, f16[2,32,32]{2,1,0}) fusion(p),
kind=kInput, calls=fused_reduce
})");
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param = LiteralUtil::CreateR3<Eigen::half>(
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
LiteralUtil::CreateR3<Eigen::half>(
{{{Eigen::half(1), Eigen::half(2)},
{Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)},
{Eigen::half(7), Eigen::half(8)}}})),
result));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
} // namespace

View File

@ -6,31 +6,18 @@ package(
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load(
"//tensorflow:tensorflow.bzl",
"py_test",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
"tf_kernel_library",
)
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
py_library(
name = "batch_py",
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:batch_ops",
"//tensorflow/python:batch_ops_gen",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:platform",
"//tensorflow/python:script_ops",
"//tensorflow/python:util",
],
)

View File

@ -18,14 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_batch_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_batch_ops import *
# pylint: enable=wildcard-import
# pylint: disable=unused-import
from tensorflow.python.ops.batch_ops import batch
from tensorflow.python.ops.batch_ops import batch_function
from tensorflow.python.ops.batch_ops import unbatch
# pylint: enable=unused-import
@ops.RegisterGradient("Batch")
@ -55,85 +54,6 @@ def _UnbatchGrad(op, grad): # pylint: disable=invalid-name
]
def batch_function(num_batch_threads,
max_batch_size,
batch_timeout_micros,
allowed_batch_sizes=None,
max_enqueued_batches=10):
"""Batches the computation done by the decorated function.
So, for example, in the following code
```python
@batch_function(1, 2, 3)
def layer(a):
return tf.matmul(a, a)
b = layer(w)
```
if more than one session.run call is simultaneously trying to compute `b`,
the values of `w` will be gathered, non-deterministically concatenated
along the first axis, and only one thread will run the computation. See the
documentation of the `Batch` op for more details.
Assumes that all arguments of the decorated function are Tensors which will
be batched along their first dimension.
SparseTensor is not supported. The return value of the decorated function
must be a Tensor or a list/tuple of Tensors.
Args:
num_batch_threads: Number of scheduling threads for processing batches
of work. Determines the number of batches processed in parallel.
max_batch_size: Batch sizes will never be bigger than this.
batch_timeout_micros: Maximum number of microseconds to wait before
outputting an incomplete batch.
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
does nothing. Otherwise, supplies a list of batch sizes, causing the op
to pad batches up to one of those sizes. The entries must increase
monotonically, and the final entry must equal max_batch_size.
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
Returns:
The decorated function will return the unbatched computation output Tensors.
"""
def decorator(fn): # pylint: disable=missing-docstring
def decorated(*args): # pylint: disable=missing-docstring
@function.defun(autograph=False)
def computation(*computation_args):
return fn(*computation_args)
computation = computation.get_concrete_function(
*[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
for i, x in enumerate(args)])
with ops.name_scope("batch") as name:
for a in args:
if not isinstance(a, ops.Tensor):
raise ValueError("All arguments to functions decorated with "
"`batch_function` are supposed to be Tensors; "
"found %s" % repr(a))
return gen_batch_ops.batch_function(
num_batch_threads=num_batch_threads,
max_batch_size=max_batch_size,
batch_timeout_micros=batch_timeout_micros,
allowed_batch_sizes=allowed_batch_sizes,
max_enqueued_batches=max_enqueued_batches,
shared_name=name,
f=computation,
in_tensors=list(args),
captured_tensors=computation.captured_inputs,
Tout=[o.dtype for o in computation.outputs])
return decorated
return decorator
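The docstring of the removed `batch_function` above describes a gather, run-once, unbatch flow: concurrent callers each contribute one example, the examples are concatenated along the first axis, the function runs once on the batch, and every caller gets back its own slice. The plain-Python sketch below (NumPy plus a queue, explicitly not TensorFlow's Batch/Unbatch kernels) illustrates that idea under simplified assumptions:

```python
# Conceptual sketch only, NOT TensorFlow's Batch/Unbatch kernels: callers
# submit single examples, a worker concatenates them along axis 0, runs the
# computation once, and hands each caller back its slice of the output.
import queue
import numpy as np

requests = queue.Queue()

def submit(x):
    """Enqueue one example; returns a one-element list the worker fills in."""
    slot = [None]
    requests.put((x, slot))
    return slot

def batched_worker(fn, max_batch_size=2, timeout_s=0.1):
    items = [requests.get()]
    while len(items) < max_batch_size:                     # analogue of max_batch_size
        try:
            items.append(requests.get(timeout=timeout_s))  # analogue of batch_timeout_micros
        except queue.Empty:
            break
    batch = np.concatenate([x for x, _ in items], axis=0)  # gather along the first axis
    out = fn(batch)                                        # run the computation once
    start = 0
    for x, slot in items:                                  # "unbatch": return each slice
        slot[0] = out[start:start + x.shape[0]]
        start += x.shape[0]

r1 = submit(np.array([1.0]))
r2 = submit(np.array([2.0]))
batched_worker(lambda a: a + 1)
print(r1[0], r2[0])  # [2.] [3.]
```

In the real op the concatenation order across callers is non-deterministic and the batched computation runs on whichever thread wins, as the docstring notes; the sketch only makes the data flow concrete.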
def batch_function_v1(num_batch_threads,
max_batch_size,
batch_timeout_micros,

View File

@ -23,12 +23,8 @@ import time
from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
@ -41,153 +37,6 @@ def delayed_plus1(x):
class BatchOpsTest(test.TestCase):
"""Tests for batch_ops.{un,}batch."""
def testBasicBatch(self):
"""Tests that a single batched tensor executes together and only once."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=36000000, grad_timeout_micros=0,
batching_queue="")
thread_results = []
def worker():
thread_results.extend(
sess.run([batched, index], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([batched, index], feed_dict={inp: [2]})
worker_thread.join()
# At this point either the thread or the main did the batch and the other
# should have empty results.
if list(thread_results[0][0]):
batch_t = thread_results[0][0]
index_t = thread_results[1]
empty_b = main_results[0][0]
empty_m = main_results[1]
else:
batch_t = main_results[0][0]
index_t = main_results[1]
empty_b = thread_results[0][0]
empty_m = thread_results[1]
# Check that both the inputs made it out exactly once.
self.assertAllEqual(sorted(batch_t), (1, 2))
# Check that we get 2 rows in the index tensor.
self.assertEqual(len(index_t), 2)
# Check that the other ones are empty.
self.assertEqual(len(empty_b), 0)
self.assertEqual(len(empty_m), 0)
def testBatchWithPadding(self):
"""Test that batching with padding up to an allowed batch size works."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[5, 10],
grad_timeout_micros=0, batching_queue="")
thread_results = []
def worker():
thread_results.extend(
sess.run([batched, index], feed_dict={inp: [1, 3]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([batched, index], feed_dict={inp: [2, 4]})
worker_thread.join()
# At this point either the thread or the main did the batch and the other
# should have empty results.
if list(thread_results[0][0]):
batch_t = thread_results[0][0]
else:
batch_t = main_results[0][0]
# Check that the batch tensor incorporates the padding.
self.assertEqual(len(batch_t), 5)
def testMultipleBatch(self):
"""Tests that multiple batched tensors execute together."""
with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
[inp0, inp1],
num_batch_threads=1,
max_batch_size=2,
batch_timeout_micros=36000000,
grad_timeout_micros=0,
batching_queue="")
thread_results = []
def worker():
thread_results.extend(
sess.run([batched], feed_dict={inp0: [1],
inp1: [2]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]})
worker_thread.join()
# At this point either the thread or the main did the batch and the other
# should have empty results.
if list(thread_results[0][0]):
batch_t = thread_results[0]
empty_t = main_results[0]
else:
batch_t = main_results[0]
empty_t = thread_results[0]
# Assert that the tensors were batched together.
self.assertAllEqual(sorted(batch_t[0]), [1, 2])
self.assertAllEqual(sorted(batch_t[1]), [2, 3])
self.assertAllEqual(empty_t[0], [])
self.assertAllEqual(empty_t[1], [])
def testIllegalBatchDifferentDim0Sizes(self):
"""Tests illegally feeding tensors with different dim0 sizes."""
with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp0, inp1], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="")
with self.assertRaises(Exception) as raised:
_ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]})
self.assertGreater(
raised.exception.message.find("must have equal 0th-dimension size"),
0)
def testBasicUnbatch(self):
"""Tests that batch and unbatch work together."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[3, 10],
grad_timeout_micros=0, batching_queue="")
computation = batched[0] + 1
result = batch_ops.unbatch(computation, index, id_t,
timeout_micros=1000000, shared_name="unbatch")
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBasicUnbatchV1Decorated(self):
"""Tests that the batch_function_v1 decorator works."""
with self.cached_session() as sess:
@ -210,206 +59,6 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
with self.cached_session() as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
self.assertTrue(in_t.shape is not None)
return in_t + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchDecoratedWithCapturedInput(self):
"""Tests that the batch_function decorator works."""
with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return in_t + captured_inp0 - captured_inp1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchFunctionOp(self):
"""Tests that the batch_function op works."""
with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
return in_t + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = gen_batch_ops.batch_function(
[inp],
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000,
Tout=[dtypes.int32],
f=computation,
captured_tensors=computation.captured_inputs)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchFunctionOpWithCapturedInput(self):
"""Tests that batch_function op works with captured input."""
with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32)
def computation(inp):
return inp + captured_inp0 - captured_inp1
result = gen_batch_ops.batch_function(
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[3, 10],
batching_queue="",
f=computation,
in_tensors=[inp],
captured_tensors=computation.captured_inputs,
Tout=[o.type for o in computation.definition.signature.output_arg])
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchFunctionOpWithInputError(self):
"""Tests that batch_function op works with error in the inputs."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
def computation(in0, in1):
return in0 + in1
result = gen_batch_ops.batch_function(
[inp], # computation actually expects 2 inputs.
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000, # 100ms
batching_queue="",
f=computation,
captured_tensors=computation.captured_inputs,
Tout=[o.type for o in computation.definition.signature.output_arg])
with self.assertRaisesRegexp(InvalidArgumentError,
".*2 arguments.*but 1.*"):
sess.run([result], feed_dict={inp: [2]})
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return array_ops.reshape(in_t, [-1]) + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [[1]]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [[2]]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=36000000, grad_timeout_micros=0,
batching_queue="")
computation = batched[0] + 1
timeout_micros = 10
result = batch_ops.unbatch(computation, index, id_t, timeout_micros,
shared_name="shared_unbatch")
# Set up a parallel pipeline that delays the computation, but uses the
# same unbatch resource object as the non-delayed pipeline.
computation_delayed = script_ops.py_func(delayed_plus1,
[batched[0]],
dtypes.int32)
result_delayed = batch_ops.unbatch(computation_delayed,
index,
id_t,
timeout_micros,
shared_name="shared_unbatch")
thread_results = []
def worker():
# A first call using the non-delayed pipeline. The batcher will send an
# empty tensor along the non-delayed pipeline.
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
time.sleep(0.1) # Ensure the thread's call starts first.
# A second call using the delayed pipeline. The batcher will send the
# batched tensor along the delayed pipeline, thus delaying the arrival of
# the batched tensor at the unbatch op, relative to the empty tensor.
#
# TODO(olston, apassos): Avoid relying on the order in which the batch op
# emits the empty tensor versus the batched one.
_ = sess.run([result_delayed], feed_dict={inp: [2]})
worker_thread.join()
# The thread's call should hit the timeout, and thus get 0 results.
self.assertEqual(len(thread_results), 0)
def testUnbatchGrad(self):
"""Tests that batch and unbatch are differentiable."""
with self.cached_session() as sess:
@ -434,6 +83,5 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [4])
if __name__ == "__main__":
test.main()

View File

@ -108,8 +108,8 @@ class NumpyState(base.Trackable):
except AttributeError:
value = _NumpyWrapper(value)
self._track_trackable(value, name=name, overwrite=True)
elif (name not in ("_setattr_tracking", "_update_uid")
and getattr(self, "_setattr_tracking", True)):
elif (name not in ("_self_setattr_tracking", "_self_update_uid")
and getattr(self, "_self_setattr_tracking", True)):
# Mixing restore()-created attributes with user-added trackable
# objects is tricky, since we can't use the `_lookup_dependency` trick to
# re-create attributes (we might accidentally steal the restoration for
@ -154,4 +154,3 @@ class _NumpyWrapper(core_python_state.PythonState):
self.array = numpy.load(string_file, allow_pickle=False)
finally:
string_file.close()

View File

@ -21,19 +21,17 @@ from __future__ import print_function
import collections
import itertools
import os
import unittest
from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
@ -44,7 +42,6 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib
@ -339,7 +336,7 @@ def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs):
return [dict(t) for t in {tuple(d.items()) for d in res}]
class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def _test_training_helper(self,
num_units,
@ -382,12 +379,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_training(self, num_units, input_size, batch_size, time, num_layers,
variable_seq_lengths, time_major, dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_training_helper(
num_units,
input_size,
@ -406,13 +400,10 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_training_fp16(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_training_helper(
num_units,
input_size,
@ -433,12 +424,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_inference(self, num_units, input_size, batch_size, time, num_layers,
variable_seq_lengths, time_major, dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
(outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
sess,
@ -465,13 +453,10 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_inference_fp16(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
(outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
sess,
@ -502,14 +487,11 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
"""Validates that dropout does not affect Cudnn Rnn inference."""
if not context.context().num_gpus():
self.skipTest("No GPUs found")
# Hand-picked dropouts are used below (0. and 1.)
with ops.Graph().as_default() as g:
with self.session(use_gpu=True, graph=g) as sess:
@ -721,7 +703,7 @@ def RunGRU(sess,
return outputs, cu_outputs, h, cu_h
class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
class CudnnGRUTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def _test_training_helper(self,
num_units,
@ -764,12 +746,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_training(self, num_units, input_size, batch_size, time, num_layers,
variable_seq_lengths, time_major, dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_training_helper(
num_units,
input_size,
@ -788,13 +767,10 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_training_fp16(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_training_helper(
num_units,
input_size,
@ -815,12 +791,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_inference(self, num_units, input_size, batch_size, time, num_layers,
variable_seq_lengths, time_major, dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
(outputs, cu_outputs, h, cu_h) = RunGRU(
sess,
@ -843,13 +816,10 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_inference_fp16(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
(outputs, cu_outputs, h, cu_h) = RunGRU(
sess,
@ -875,15 +845,12 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
"""Validates that dropout does not affect Cudnn Rnn inference."""
# Hand-picked dropouts are used below (0. and 1.)
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with ops.Graph().as_default() as g:
with self.session(use_gpu=True, graph=g) as sess:
# 1st time w/o dropout.
@ -919,7 +886,7 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
self.assertAllClose(cu_h[0], cu_h2[0])
class CudnnParamsFormatConverterTest(TensorFlowTestCase,
class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
"""Class for testing various format converters."""
@ -970,22 +937,16 @@ class CudnnParamsFormatConverterTest(TensorFlowTestCase,
@parameterized.named_parameters((c["testcase_name"], c["num_units"],
c["input_size"], c["num_layers"])
for c in NAMED_RNN_TESTCASES)
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_lstm(self, num_units, input_size, num_layers):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_lstm_helper(num_units, input_size, num_layers,
cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
@parameterized.named_parameters((c["testcase_name"], c["num_units"],
c["input_size"], c["num_layers"])
for c in NAMED_RNN_TESTCASES)
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_lstm_bidi(self, num_units, input_size, num_layers):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_lstm_helper(num_units, input_size, num_layers,
cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)
@ -1044,27 +1005,22 @@ class CudnnParamsFormatConverterTest(TensorFlowTestCase,
@parameterized.named_parameters((c["testcase_name"], c["num_units"],
c["input_size"], c["num_layers"])
for c in NAMED_RNN_TESTCASES)
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_gru(self, num_units, input_size, num_layers):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_gru_helper(num_units, input_size, num_layers,
cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
@parameterized.named_parameters((c["testcase_name"], c["num_units"],
c["input_size"], c["num_layers"])
for c in NAMED_RNN_TESTCASES)
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_gru_bidi(self, num_units, input_size, num_layers):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_gru_helper(num_units, input_size, num_layers,
cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)
class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase):
class CudnnRnnSaveRestoreTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
"""Class for testing various Cudnn Rnn SaveableObjects."""
def _create_opaque_param(self,
@ -1112,14 +1068,11 @@ class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase):
],
"direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION]
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_save_restore_variable(self, rnn_mode, num_units, input_size,
num_layers, direction):
# Verify the restored opaque param, once converted to tf_canonical format,
# is the same as the tf canonicals of the pre-restored param.
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
opaque_param = self._create_opaque_param(rnn_mode, num_units, input_size,
num_layers, direction)
@ -1164,14 +1117,11 @@ class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase):
],
"direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION]
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_gpu_only
def test_save_restore_multi_variables(self, rnn_mode, num_units, input_size,
num_layers, direction):
# Verify the restored opaque param, once converted to tf_canonical format,
# is the same as the tf canonicals of the pre-restored param.
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
opaque_params = []
saveables = []
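Throughout this file the pattern `@unittest.skipUnless(test.is_built_with_cuda(), ...)` plus a runtime `if not context.context().num_gpus(): self.skipTest(...)` is collapsed into the single `@test_util.run_gpu_only` decorator. A minimal sketch (hypothetical, not TensorFlow's implementation of `test_util.run_gpu_only`) of what such a decorator amounts to:

```python
# Hypothetical sketch, NOT tensorflow.python.framework.test_util.run_gpu_only:
# skip the wrapped test method when no GPU is available, folding the
# build-time and runtime checks shown in the diff into one decorator.
import functools
import unittest

def run_gpu_only_sketch(gpu_available):
    """`gpu_available` stands in for a check like context.context().num_gpus()."""
    def decorator(test_method):
        @functools.wraps(test_method)
        def wrapper(self, *args, **kwargs):
            if not gpu_available():
                raise unittest.SkipTest("No GPUs found")
            return test_method(self, *args, **kwargs)
        return wrapper
    return decorator
```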

Some files were not shown because too many files have changed in this diff.