Merge branch 'master' into doc-patch-batch-normalization
commit 6004a65a77
@@ -25,7 +25,7 @@ networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.

TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards
compatible API's for C++, Go, Java, JavaScript and Swift.
compatible API's for C++, Go, Java, JavaScript, and Swift.

Keep up to date with release announcements and security updates by
subscribing to
WORKSPACE
@@ -43,47 +43,37 @@ remote_config_workspace()
# Apple and Swift rules.
http_archive(
    name = "build_bazel_rules_apple",
    sha256 = "4b90786009fa8df25230442244bad2832ba8d6bc4987f68150a7de59c8827e90",
    strip_prefix = "rules_apple-0.14.0",
    urls = ["https://github.com/bazelbuild/rules_apple/archive/0.14.0.tar.gz"],
)
http_file(
    name = "xctestrunner",
    executable = 1,
    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"],
)

http_archive(
    name = "bazel_skylib",
    sha256 = "2c62d8cd4ab1e65c08647eb4afe38f51591f43f7f0885e7769832fa137633dcb",
    strip_prefix = "bazel-skylib-0.7.0",
    urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.7.0.tar.gz"],
)

    sha256 = "8f32e2839fba28d549e1670dbed83606dd339a9f7489118e481814d61738270f",
    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.14.0/rules_apple.0.14.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
    name = "build_bazel_apple_support",
    sha256 = "835663c4bb02f4bf01dce8a2a176df7fa682dbb867d3698ae12258c1628bb8f0",
    strip_prefix = "apple_support-0.5.0",
    urls = ["https://github.com/bazelbuild/apple_support/archive/0.5.0.tar.gz"],
)

    sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
    urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
    name = "bazel_skylib",
    sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
    urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
    name = "build_bazel_rules_swift",
    sha256 = "32d124878cd49775d84f59ba90440c8b23b7c775aec8fec1978f751c76ddee8a",
    strip_prefix = "rules_swift-0.7.0",
    urls = ["https://github.com/bazelbuild/rules_swift/archive/0.7.0.tar.gz"],
)

    sha256 = "31aad005a9c4e56b256125844ad05eb27c88303502d74138186f9083479f93a6",
    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.8.0/rules_swift.0.8.0.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
    name = "com_github_apple_swift_swift_protobuf",
    type = "zip",
    strip_prefix = "swift-protobuf-1.2.0/",
    urls = ["https://github.com/apple/swift-protobuf/archive/1.2.0.zip"],
)

# Use swift_rules_dependencies to fetch the tolchains.
# Since we defined all the "git_repository" rules above, the following call will
# skip redefining them.
    strip_prefix = "swift-protobuf-1.4.0/",
    urls = ["https://github.com/apple/swift-protobuf/archive/1.4.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
    name = "xctestrunner",
    executable = 1,
    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
swift_rules_dependencies()
@@ -799,8 +799,8 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
  const auto& op_type = op->operation.Name();
  auto op_name =
      tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
  auto* desc =
      TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
  std::unique_ptr<TF_OperationDescription> desc(
      TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()));

  VLOG(1) << "Adding attrs.";
  tensorflow::AttrValueMap attrs;
@@ -814,30 +814,42 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
  size_t inputIndex = 0;
  const tensorflow::OpDef& op_def = desc->node_builder.op_def();
  for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
    // TODO(bgogul): Add support for number attributes.
    DCHECK(input_arg.number_attr().empty())
        << "Number attributes is not implemented yet.";
    if (input_arg.type_list_attr().empty()) {
    if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) {
      auto symbolic_input =
          getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
      if (!status->status.ok()) return nullptr;
      TF_AddInput(desc, symbolic_input);
      TF_AddInput(desc.get(), symbolic_input);
      continue;
    }
    const std::string& type_list_attr = input_arg.type_list_attr();
    const auto& attr_value = attrs[type_list_attr];
    DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
        << "Type list attribute should be a list!";
    std::vector<TF_Output> list_inputs(attr_value.list().type_size());
    size_t list_size = 0;
    if (!input_arg.type_list_attr().empty()) {
      const std::string& type_list_attr = input_arg.type_list_attr();
      const auto& attr_value = attrs[type_list_attr];
      CHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
          << "Type list attribute should be a list!";
      list_size = attr_value.list().type_size();
    } else {
      CHECK(!input_arg.number_attr().empty());
      const auto& attr_value = attrs[input_arg.number_attr()];
      CHECK(attr_value.value_case() == tensorflow::AttrValue::kI)
          << "Number attribute should be int!";
      if (attr_value.i() < 0) {
        status->status = tensorflow::errors::Internal(
            "Number attribute for length should be >=0!");
        return nullptr;
      }
      list_size = attr_value.i();
    }
    std::vector<TF_Output> list_inputs(list_size);
    for (TF_Output& list_input : list_inputs) {
      list_input =
          getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
      if (!status->status.ok()) return nullptr;
    }
    TF_AddInputList(desc, list_inputs.data(), list_inputs.size());
    TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size());
  }

  auto* graph_op = TF_FinishOperation(desc, status);
  auto* graph_op = TF_FinishOperation(desc.release(), status);
  if (!status->status.ok()) return nullptr;

  VLOG(1) << "Op finalized; setting return tensors.";
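As an aside on the length-resolution logic added above: the new branch picks the input-list length either from a type-list attribute or from an integer number attribute, and rejects negative counts. Below is a minimal, self-contained sketch of that decision; ArgSpec, AttrVal, and ResolveInputListSize are invented stand-ins for the real OpDef::ArgDef/AttrValue protos, not TensorFlow API.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Invented stand-ins for OpDef::ArgDef and AttrValue: an input argument names
// either a type-list attribute or a number attribute, and the attribute map
// stores either a list length or an integer count.
struct ArgSpec {
  std::string type_list_attr;  // e.g. "T" for IdentityN-style inputs
  std::string number_attr;     // e.g. "N" for ConcatV2-style inputs
};
struct AttrVal {
  int list_size = 0;  // used when the attribute is a type list
  int64_t i = 0;      // used when the attribute is an integer
};

// Returns how many tensors the argument consumes, or -1 for a negative number
// attribute, mirroring the "Number attribute for length should be >=0!" path.
int ResolveInputListSize(const ArgSpec& arg,
                         const std::map<std::string, AttrVal>& attrs) {
  if (!arg.type_list_attr.empty()) {
    return attrs.at(arg.type_list_attr).list_size;
  }
  const int64_t n = attrs.at(arg.number_attr).i;
  return n < 0 ? -1 : static_cast<int>(n);
}

int main() {
  std::map<std::string, AttrVal> attrs;
  attrs["N"].i = 2;  // like setting the "N" attribute to 2 on a ConcatV2 op
  ArgSpec concat_values{/*type_list_attr=*/"", /*number_attr=*/"N"};
  std::cout << ResolveInputListSize(concat_values, attrs) << "\n";  // prints 2
  return 0;
}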
@@ -376,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
  TFE_DeleteOp(identityn);
}

TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) {
  TFE_TensorHandle* matrix = TestMatrixTensorHandle();
  TFE_TensorHandle* axis = TestAxisTensorHandle();
  TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
  TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
  TFE_OpSetAttrInt(concatv2, "N", 2);
  TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
  constexpr size_t kNumInputs = 2;
  for (size_t i = 0; i < kNumInputs; ++i) {
    TFE_OpAddInput(concatv2, matrix, status_);
    CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
  }
  TFE_OpAddInput(concatv2, axis, status_);
  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
  AddEagerOpToGraphAndCheck(
      concatv2, [this, kNumInputs](TF_Operation* graph_op) {
        EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1);
        int64_t attrN;
        TF_OperationGetAttrInt(graph_op, "N", &attrN, status_);
        CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
        EXPECT_EQ(attrN, kNumInputs);
        EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_),
                  kNumInputs);
        CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
      });
  TFE_DeleteTensorHandle(axis);
  TFE_DeleteTensorHandle(matrix);
  TFE_DeleteOp(concatv2);
}

TEST_F(AddEagerOpToGraphTest,
       GeneratesInternalErrorsForInvalidNumberAttributes) {
  TFE_TensorHandle* matrix = TestMatrixTensorHandle();
  TFE_TensorHandle* axis = TestAxisTensorHandle();
  int num_retvals = 5;
  TFE_TensorHandle* retvals[5];

  TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
  CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
  TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
  TFE_OpSetAttrInt(concatv2, "N", -1);
  TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);

  TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals,
                                                 &num_retvals, status_);
  EXPECT_EQ(graph_op, nullptr);
  EXPECT_EQ(status_->status.error_message(),
            "Number attribute for length should be >=0!");

  TFE_DeleteOp(concatv2);
  TFE_DeleteTensorHandle(axis);
  TFE_DeleteTensorHandle(matrix);
}

} // namespace
} // namespace tensorflow
@@ -143,7 +143,9 @@ tensorflow::Status CreateRemoteContexts(
    request.mutable_server_def()->set_task_index(parsed_name.task);
    request.set_async(async);
    request.set_keep_alive_secs(keep_alive_secs);
    auto* eager_client = remote_eager_workers->GetClient(remote_worker);
    tensorflow::eager::EagerClient* eager_client;
    TF_RETURN_IF_ERROR(
        remote_eager_workers->GetClient(remote_worker, &eager_client));
    if (eager_client == nullptr) {
      return tensorflow::errors::Internal(
          "Cannot find a client for the given target:", remote_worker);
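For context on the API change above: GetClient now reports lookup failure through a returned status and fills an out-parameter instead of returning a raw pointer. A minimal sketch of that calling convention, with invented Client, Status, and ClientCache types standing in for the real eager-client classes:

#include <iostream>
#include <map>
#include <string>

// Invented stand-ins for the eager client cache and Status type.
struct Client { std::string target; };
struct Status {
  bool ok = true;
  std::string message;
};

class ClientCache {
 public:
  // New-style lookup: report failure through Status, fill an out-parameter.
  Status GetClient(const std::string& target, Client** client) {
    auto it = clients_.find(target);
    if (it == clients_.end()) {
      *client = nullptr;
      return Status{false,
                    "Cannot find a client for the given target: " + target};
    }
    *client = &it->second;
    return Status{};
  }
  void Add(const std::string& target) { clients_[target] = Client{target}; }

 private:
  std::map<std::string, Client> clients_;
};

int main() {
  ClientCache cache;
  cache.Add("/job:worker/replica:0/task:1");

  Client* client = nullptr;
  Status s = cache.GetClient("/job:worker/replica:0/task:1", &client);
  std::cout << (s.ok && client != nullptr) << "\n";  // 1: lookup succeeded

  s = cache.GetClient("/job:worker/replica:0/task:7", &client);
  std::cout << s.ok << " " << s.message << "\n";     // 0 plus the error text
  return 0;
}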
@@ -80,11 +80,14 @@ void ExecuteWithProfiling(bool async) {
                                  profiler_result->length}));
  string profile_proto_str = profile_proto.DebugString();
  if (!gpu_device_name.empty()) {
    EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0"));
    EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
    // device name with "stream:all" is collected by Device Tracer.
    EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
  }
  EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0"));
  EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:CPU:0"));
  // This is collected by TraceMe
  EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
  EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
  TF_DeleteBuffer(profiler_result);

  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
@@ -21,7 +21,6 @@ from __future__ import print_function as _print_function
import os as _os

# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import

# API IMPORTS PLACEHOLDER
@@ -732,44 +732,6 @@ tf_cc_test(
    ],
)

cc_library(
    name = "xla_fusion_optimizer",
    srcs = ["xla_fusion_optimizer.cc"],
    hdrs = ["xla_fusion_optimizer.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":common",
        ":compilation_passes",
        ":union_find",
        ":xla_cluster_util",
        "//tensorflow/compiler/jit/graphcycles",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
        "@com_google_absl//absl/strings",
    ],
)

tf_cuda_cc_test(
    name = "xla_fusion_optimizer_test",
    srcs = ["xla_fusion_optimizer_test.cc"],
    deps = [
        ":common",
        ":xla_cluster_util",
        ":xla_fusion_optimizer",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:resource_variable_ops",
        "//tensorflow/core:graph",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/grappler/utils:grappler_test",
    ],
)

cc_library(
    name = "node_matchers",
    testonly = True,
@@ -441,19 +441,48 @@ class MarkForCompilationPassImpl {
    bool has_xla_compile_attr;
  };

  // Nodes that XLA can compile are put in `candidates`. Nodes put in
  // `isolated_nodes` must either be unclustered or be put in trivial
  // ---------------------------------------------------------------------------
  // The pass proceeds in four steps, out of which `RunEdgeContractionLoop` and
  // `CreateClusters` do most of the heavy lifting.

  // Initialize some internal data structures.
  Status Initialize();

  // Contracts as many edges as possible to create XLA clusters. After this
  // finishes the clustering decisions made are implicitly stored in
  // `clusters_`.
  Status RunEdgeContractionLoop();

  // Manifests the clustering decisions into the TF graph by tagging nodes with
  // an `_XlaCluster` attribute. Also some basic filter logic, like
  // tf_xla_min_cluster_size, are applied here.
  Status CreateClusters();

  Status DumpDebugInfo();

  bool IsCompilationCandidate(Node* n) const {
    return compilation_candidates_.find(n) != compilation_candidates_.end();
  }

  // Tries to contract the edge from cluster `from` to cluster `to`. Returns
  // true if successful.
  StatusOr<bool> TryToContractEdge(const Cluster& from, int to);

  // Tries to contract each edge from `cluster_from`. Returns true as soon as a
  // single edge contraction is successful. Returns false if no edges were
  // contracted.
  StatusOr<bool> TryToContractEdgeFrom(Cluster* cluster_from);

  // Nodes that XLA can compile are put in `candidates_`. Nodes put in
  // `isolated_nodes_` must either be unclustered or be put in trivial
  // single-node clusters.
  StatusOr<std::pair<OrderedNodeSet, absl::flat_hash_set<Node*>>>
  FindCompilationCandidates();
  Status FindCompilationCandidates();

  bool CompilationDisallowedByXlaCompileAttr(Node* node,
                                             const DeviceType& jit_device_type);

  Status BuildInitialClusterSet(const OrderedNodeSet& compilation_candidates,
                                const DeadnessAnalysis* deadness_analysis,
                                std::vector<UnionFind<Cluster>>* clusters,
                                std::deque<UnionFind<Cluster>*>* worklist);
  // Populates `clusters_` and `worklist_`.
  Status BuildInitialClusterSet();

  StatusOr<bool> ShouldCompileClusterImpl(const Cluster& cluster);
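A note on the structure introduced above: the pass is split into four sequenced phases, and (further down in this diff) each phase is guarded by initialized_, edges_contracted_, and clusters_created_ flags checked with TF_RET_CHECK. The toy driver below illustrates only that phase discipline; plain assert and bool stand in for TF_RET_CHECK and Status, and ToyPass is not the real class.

#include <cassert>
#include <iostream>

// Toy pass with the same phase discipline as MarkForCompilationPassImpl:
// each phase asserts that earlier phases ran and later ones did not.
class ToyPass {
 public:
  bool Run() {
    // Mirrors the new Run(): chain the phases and stop on the first failure.
    return Initialize() && RunEdgeContractionLoop() && CreateClusters() &&
           DumpDebugInfo();
  }

 private:
  bool Initialize() {
    assert(!initialized_ && !edges_contracted_ && !clusters_created_);
    initialized_ = true;
    return true;
  }
  bool RunEdgeContractionLoop() {
    assert(initialized_ && !edges_contracted_ && !clusters_created_);
    edges_contracted_ = true;
    return true;
  }
  bool CreateClusters() {
    assert(initialized_ && edges_contracted_ && !clusters_created_);
    clusters_created_ = true;
    return true;
  }
  bool DumpDebugInfo() {
    assert(initialized_ && edges_contracted_ && clusters_created_);
    return true;
  }

  bool initialized_ = false;
  bool edges_contracted_ = false;
  bool clusters_created_ = false;
};

int main() {
  ToyPass pass;
  std::cout << (pass.Run() ? "ok" : "failed") << "\n";
  return 0;
}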
@@ -461,9 +490,7 @@ class MarkForCompilationPassImpl {

  bool HasMismatchingXlaScope(Node* node_from, Node* node_to);

  StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
      int to_node_id, const OrderedNodeSet& compilation_candidates,
      absl::Span<UnionFind<Cluster>> clusters, const GraphCycles& cycles);
  StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(int to);

  // Returns true if the devices in `cluster_a` and `cluster_b` are compatible
  // and therefore not a hindrance for combining the two clusters into a larger
@@ -481,13 +508,217 @@ class MarkForCompilationPassImpl {
  OptimizerOptions::GlobalJitLevel global_jit_level_;
  absl::flat_hash_map<int, bool> should_compile_cluster_cache_;
  DeviceInfoCache device_info_cache_;

  bool initialized_ = false;
  bool edges_contracted_ = false;
  bool clusters_created_ = false;

  std::vector<UnionFind<Cluster>> clusters_;
  std::deque<UnionFind<Cluster>*> worklist_;
  GraphCycles graph_cycles_;
  OrderedNodeSet compilation_candidates_;
  absl::flat_hash_set<Node*> isolated_nodes_;
  std::unique_ptr<DeadnessAnalysis> deadness_analysis_;
  int64 iteration_count_ = 0;
};

Status IgnoreResourceOpForSafetyAnalysis(DeviceInfoCache* device_info_cache,
                                         const Node& n, bool* ignore) {
  // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
  // ignore it during resource operation safety analysis. We need this hack
  // because of two reasons:
  //
  // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
  // 2. We don't support live-out values of type DT_RESOURCE and live-in values
  // of type DT_RESOURCE that are not resource variables.
  //
  // Together these imply we cannot let resource variable safety analysis
  // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
  // clusters: both of them will have to be clustered because of (1) and we
  // won't be able to keep the edge between the two as neither the input to the
  // second XLA cluster nor the output from the first XLA cluster are supported
  // because of (2).
  //
  // TODO(b/113100872): This can be fixed if the TensorFlow representation for
  // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
  // (2) would no longer hold.

  if (n.assigned_device_name().empty()) {
    *ignore = false;
    return Status::OK();
  }

  TF_ASSIGN_OR_RETURN(
      const XlaOpRegistry::DeviceRegistration* registration,
      device_info_cache->GetCompilationDevice(n.assigned_device_name()));

  if (!registration) {
    *ignore = true;
  } else {
    *ignore = registration->cluster_resource_variable_ops_unsafely;
  }
  return Status::OK();
}

Status MarkForCompilationPassImpl::Initialize() {
  TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_);
  initialized_ = true;

  TF_RETURN_IF_ERROR(FindCompilationCandidates());

  if (compilation_candidates_.empty()) {
    VLOG(2) << "No compilable candidates";
    return Status::OK();
  }

  TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
                      CreateCycleDetectionGraph(graph_, &graph_cycles_));
  if (!cycle_detection_graph_ok) {
    return Status::OK();
  }

  auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
    return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
  };

  TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
      graph_, flib_def_, ignore_resource_ops, &graph_cycles_));

  if (!debug_options_.ignore_deadness_checks) {
    XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
    TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_));
  }

  // Each compilation candidate belongs to a cluster. The cluster's
  // representative names the node in the 'cycles' graph that represents the
  // cluster.
  return BuildInitialClusterSet();
}

Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
  TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_);
  edges_contracted_ = true;

  // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
  // example, from the Grappler fusion pass).
  while (!worklist_.empty()) {
    UnionFind<Cluster>* cluster_from = worklist_.front();
    worklist_.pop_front();

    TF_ASSIGN_OR_RETURN(bool contracted_one_edge,
                        TryToContractEdgeFrom(&cluster_from->Get()));

    if (contracted_one_edge) {
      worklist_.push_back(cluster_from);
    }
  }

  VLOG(1) << iteration_count_ << " iterations in inner loop for graph with "
          << compilation_candidates_.size()
          << " compilation candidates. Iterations per compilation candidate: "
          << ((1.0 * iteration_count_) / compilation_candidates_.size());

  return Status::OK();
}

Status MarkForCompilationPassImpl::CreateClusters() {
  TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
  clusters_created_ = true;

  static std::atomic<int64> cluster_sequence_num;

  // Count the number of non-trivial elements in each cluster.
  std::vector<int> effective_cluster_sizes(graph_->num_node_ids());

  // has_functional_control_flow remembers if a cluster contains a functional
  // control flow node.
  std::vector<bool> has_functional_control_flow(graph_->num_node_ids());

  for (const Node* n : compilation_candidates_) {
    int cluster = clusters_[n->id()].Get().representative;
    // We want clusters to be big enough that the benefit from XLA's
    // optimizations offsets XLA related overhead (for instance we add some
    // Switch/Merge nodes into the graph to implement lazy compilation). To
    // this end, we don't count Identity and Constant nodes because they do not
    // enable interesting optimizations by themselves.
    if (!n->IsIdentity() && !n->IsConstant()) {
      effective_cluster_sizes[cluster]++;
    }
    if (n->type_string() == "While" || n->type_string() == "If") {
      has_functional_control_flow[cluster] = true;
    }
  }

  // Names for each cluster.
  std::unordered_map<int, string> cluster_names;

  if (debug_options_.dump_graphs) {
    DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
  }

  // Mark clusters for compilation that:
  // * are placed on a device that requires compilation (an XlaDevice),
  // * are explicitly marked for compilation (_XlaCompile=true), or
  // * have more than debug_options_.xla_min_cluster_size elements (applicable
  //   only if compilation is enabled, otherwise there will be no such
  //   candidates).
  for (Node* n : compilation_candidates_) {
    const Cluster& cluster = clusters_[n->id()].Get();
    TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
                        ShouldCompileCluster(cluster));
    if (!should_compile_cluster) {
      continue;
    }

    int cluster_repr = cluster.representative;

    // Compile if the user marked this node _XlaCompile=true
    bool compile_attr = false;
    bool marked_for_compilation = false;
    if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
      marked_for_compilation = compile_attr;
    } else if (flib_def_->GetAttr(*n, kXlaCompileAttr, &compile_attr).ok()) {
      marked_for_compilation = compile_attr;
    }

    // We assume that functional If and While nodes have at least
    // min_cluster_size non-trivial nodes in them. It would be more principled
    // to (recursively) verify this fact, but that's probably not worth the
    // trouble.

    if (effective_cluster_sizes[cluster_repr] >=
            debug_options_.min_cluster_size ||
        has_functional_control_flow[cluster_repr] || marked_for_compilation) {
      string& name = cluster_names[cluster_repr];

      if (name.empty()) {
        name = absl::StrCat("cluster_", cluster_sequence_num++);
      }
      n->AddAttr(kXlaClusterAttr, name);
      n->AddAttr(kXlaAlreadyClustered, true);
      VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
    }
  }

  return Status::OK();
}
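As a worked illustration of the selection rule in CreateClusters above (compile a cluster when its effective size reaches the minimum, it contains functional While/If control flow, or a node carries _XlaCompile=true), here is a small self-contained sketch over a toy node list. ToyNode, the op strings, and the threshold of 4 are invented for the example and are not part of the pass.

#include <iostream>
#include <map>
#include <string>
#include <vector>

struct ToyNode {
  int cluster;          // representative id of the cluster the node belongs to
  std::string op;       // op type, e.g. "MatMul", "Const", "While"
  bool marked = false;  // stands in for _XlaCompile=true
};

int main() {
  const int kMinClusterSize = 4;  // stands in for debug_options_.min_cluster_size
  std::vector<ToyNode> nodes = {
      {0, "MatMul"}, {0, "Relu"}, {0, "Const"}, {0, "Identity"},
      {1, "While"},  {1, "Add"},
      {2, "Mul"},    {2, "Sub", true},
  };

  // Count only "interesting" nodes, exactly as the pass skips Identity/Const.
  std::map<int, int> effective_size;
  std::map<int, bool> has_control_flow, marked;
  for (const ToyNode& n : nodes) {
    if (n.op != "Identity" && n.op != "Const") effective_size[n.cluster]++;
    if (n.op == "While" || n.op == "If") has_control_flow[n.cluster] = true;
    if (n.marked) marked[n.cluster] = true;
  }

  for (const auto& [cluster, size] : effective_size) {
    bool compile = size >= kMinClusterSize || has_control_flow[cluster] ||
                   marked[cluster];
    std::cout << "cluster_" << cluster << ": " << (compile ? "compile" : "skip")
              << "\n";
  }
  return 0;
}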
Status MarkForCompilationPassImpl::DumpDebugInfo() {
  TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_);

  if (debug_options_.dump_graphs) {
    DumpPostClusteringGraphs();
  }

  VLogClusteringSummary();

  return Status::OK();
}

StatusOr<bool>
MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
    int to_node_id, const OrderedNodeSet& compilation_candidates,
    absl::Span<UnionFind<Cluster>> clusters, const GraphCycles& cycles) {
  const Cluster& cluster_to = clusters[to_node_id].Get();
    int to) {
  const Cluster& cluster_to = clusters_[to].Get();

  // If any of the consumer's producers are on a different device, do not
  // cluster these nodes. This prevents other work on this device from being
@@ -499,14 +730,14 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
  //
  // TODO(b/117085735): We probably want to handle the reciprocal of this case
  // where a cluster is producing data for multiple devices.
  for (const auto& in_id : cycles.Predecessors(to_node_id)) {
  for (const auto& in_id : graph_cycles_.Predecessors(to)) {
    if (in_id >= graph_->num_node_ids()) {
      continue;
    }

    Node* in = graph_->FindNodeId(in_id);
    const Cluster& cluster_in = clusters[in_id].Get();
    if (compilation_candidates.find(in) != compilation_candidates.cend()) {
    const Cluster& cluster_in = clusters_[in_id].Get();
    if (IsCompilationCandidate(in)) {
      TF_ASSIGN_OR_RETURN(bool devices_compatible,
                          AreDevicesCompatible(cluster_to, cluster_in));
      if (!devices_compatible) {
@@ -537,20 +768,17 @@ bool MarkForCompilationPassImpl::HasMismatchingXlaScope(Node* node_from,
         from_scope != to_scope;
}

Status MarkForCompilationPassImpl::BuildInitialClusterSet(
    const OrderedNodeSet& compilation_candidates,
    const DeadnessAnalysis* deadness_analysis,
    std::vector<UnionFind<Cluster>>* clusters,
    std::deque<UnionFind<Cluster>*>* worklist) {
  clusters->resize(graph_->num_node_ids());
  for (Node* node : compilation_candidates) {
    Cluster* cluster = &(*clusters)[node->id()].Get();
Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
  clusters_.resize(graph_->num_node_ids());

  for (Node* node : compilation_candidates_) {
    Cluster* cluster = &clusters_[node->id()].Get();
    cluster->representative = node->id();

    if (deadness_analysis) {
    if (deadness_analysis_) {
      TF_ASSIGN_OR_RETURN(
          cluster->deadness_predicate,
          deadness_analysis->GetPredicateFor(node, Graph::kControlSlot));
          deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
    }

    const string& device = !node->assigned_device_name().empty()
@@ -572,17 +800,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet(
    }

    cluster->devices.insert(device);
    worklist->push_back(&(*clusters)[node->id()]);

    worklist_.push_back(&clusters_[node->id()]);
  }

  return Status::OK();
}

StatusOr<std::pair<OrderedNodeSet, absl::flat_hash_set<Node*>>>
MarkForCompilationPassImpl::FindCompilationCandidates() {
  OrderedNodeSet candidates;
  absl::flat_hash_set<Node*> isolated_nodes;

Status MarkForCompilationPassImpl::FindCompilationCandidates() {
  OptimizerOptions opts;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(nullptr, env_, TF_GRAPH_DEF_VERSION,
@@ -698,17 +923,18 @@ MarkForCompilationPassImpl::FindCompilationCandidates() {
        if (!is_tensor_array_or_stack_op) {
          VLOG(2) << "Isolating " << node->name()
                  << ": must-be-constant stateful op";
          isolated_nodes.insert(node);
          isolated_nodes_.insert(node);
        }
      }
    }

    candidates.insert(node);
    compilation_candidates_.insert(node);
    --(*debug_options_.fuel);
  }

  VLOG(2) << "candidates->size() = " << candidates.size();
  return {{candidates, isolated_nodes}};
  VLOG(2) << "compilation_candidates_.size() = "
          << compilation_candidates_.size();
  return Status::OK();
}

bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
@@ -753,47 +979,119 @@ bool IsShapeConsumerOp(const Node& node) {
         node.type_string() == "Size";
}

Status IgnoreResourceOpForSafetyAnalysis(DeviceInfoCache* device_info_cache,
                                         const Node& n, bool* ignore) {
  // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
  // ignore it during resource operation safety analysis. We need this hack
  // because of two reasons:
  //
  // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
  // 2. We don't support live-out values of type DT_RESOURCE and live-in values
  // of type DT_RESOURCE that are not resource variables.
  //
  // Together these imply we cannot let resource variable safety analysis
  // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
  // clusters: both of them will have to be clustered because of (1) and we
  // won't be able to keep the edge between the two as neither the input to the
  // second XLA cluster nor the output from the first XLA cluster are supported
  // because of (2).
  //
  // TODO(b/113100872): This can be fixed if the TensorFlow representation for
  // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
  // (2) would no longer hold.

  if (n.assigned_device_name().empty()) {
    *ignore = false;
    return Status::OK();
StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(
    const Cluster& cluster_from, int to) {
  Node* node_to = graph_->FindNodeId(to);
  if (!IsCompilationCandidate(node_to)) {
    return false;
  }

  TF_ASSIGN_OR_RETURN(
      const XlaOpRegistry::DeviceRegistration* registration,
      device_info_cache->GetCompilationDevice(n.assigned_device_name()));

  if (!registration) {
    *ignore = true;
  } else {
    *ignore = registration->cluster_resource_variable_ops_unsafely;
  const Cluster& cluster_to = clusters_[to].Get();
  DCHECK(cluster_from.deadness_predicate.has_value() ==
         cluster_to.deadness_predicate.has_value());
  if (cluster_from.deadness_predicate != cluster_to.deadness_predicate) {
    return false;
  }
  return Status::OK();

  TF_ASSIGN_OR_RETURN(bool devices_compatible,
                      AreDevicesCompatible(cluster_from, cluster_to));
  if (!devices_compatible) {
    return false;
  }

  if (isolated_nodes_.contains(node_to)) {
    return false;
  }

  int from = cluster_from.representative;
  Node* node_from = graph_->FindNodeId(from);

  if (HasMismatchingXlaScope(node_from, node_to)) {
    return false;
  }

  // Ops that consume shapes cannot be the root of a cluster. This is an
  // optimization.
  if (clusters_[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
    return false;
  }

  // Don't exceed the maximum cluster size.
  if (clusters_[from].Size() + clusters_[to].Size() >
      debug_options_.max_cluster_size) {
    return false;
  }

  TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
                      ClusteringWillIntroduceInterDeviceDependency(to));

  if (will_introduce_cross_device_dependency) {
    return false;
  }

  // If contracting the edge would create a cycle, bail out. However, just
  // because we can't merge the clusters now does not mean we won't be able
  // to merge them in the future. e.g., if we have edges 1->2, 2->3 and
  // 1->3, we cannot contract edge 1->3. But if we first contract 1->2 then
  // we can later contract 1->3.
  return graph_cycles_.ContractEdge(from, to);
}
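The comment at the end of TryToContractEdge is the classic argument for retrying contractions: with edges 1->2, 2->3, and 1->3, contracting 1->3 first would close a cycle, while contracting 1->2 first makes 1->3 contractible later. The sketch below demonstrates that ordering effect under a simplifying assumption (an edge u->v is contractible exactly when v is not reachable from u via a longer path); it is a toy model, not the GraphCycles algorithm.

#include <iostream>
#include <map>
#include <set>

// Toy model of cluster contraction: nodes are cluster ids, edges are directed.
// Contracting u->v is refused when v is reachable from u through a longer
// path, since the merge would otherwise close a cycle.
struct ToyGraph {
  std::map<int, std::set<int>> succ;

  void AddEdge(int u, int v) { succ[u].insert(v); }

  bool Reaches(int cur, int target, std::set<int>* visited) {
    if (cur == target) return true;
    if (!visited->insert(cur).second) return false;
    for (int next : succ[cur]) {
      if (Reaches(next, target, visited)) return true;
    }
    return false;
  }

  // True if some path of length >= 2 leads from u to v.
  bool HasLongerPath(int u, int v) {
    for (int w : succ[u]) {
      if (w == v) continue;  // ignore the direct edge itself
      std::set<int> visited;
      if (Reaches(w, v, &visited)) return true;
    }
    return false;
  }

  // Contract u->v, merging v into u; returns false if that would create a cycle.
  bool ContractEdge(int u, int v) {
    if (HasLongerPath(u, v)) return false;
    succ[u].erase(v);
    for (int w : succ[v]) {
      if (w != u) succ[u].insert(w);  // v's successors become u's
    }
    succ.erase(v);
    for (auto& [node, outs] : succ) {
      if (outs.erase(v) > 0 && node != u) outs.insert(u);  // redirect x->v to x->u
    }
    return true;
  }
};

int main() {
  ToyGraph g;
  g.AddEdge(1, 2);
  g.AddEdge(2, 3);
  g.AddEdge(1, 3);

  ToyGraph eager = g;  // try the "bad" order first
  std::cout << "contract 1->3 first: " << eager.ContractEdge(1, 3) << "\n";  // 0

  std::cout << "contract 1->2 first: " << g.ContractEdge(1, 2) << "\n";  // 1
  std::cout << "then contract 1->3:  " << g.ContractEdge(1, 3) << "\n";  // 1
  return 0;
}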
StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdgeFrom(
    Cluster* cluster_from) {
  int from = cluster_from->representative;

  Node* node_from = graph_->FindNodeId(from);
  if (node_from->IsControlFlow()) {
    // Control flow nodes aren't compilation candidates and should never
    // appear.
    return errors::Internal("Found control flow node in clustering worklist: ",
                            node_from->type_string());
  }

  if (!IsCompilationCandidate(node_from)) {
    return false;
  }

  if (isolated_nodes_.count(node_from)) {
    return false;
  }

  for (int to : graph_cycles_.Successors(from)) {
    iteration_count_++;
    if (to >= graph_->num_node_ids()) {
      // Node is a fictitious node that is present only in the cycle detection
      // graph. No clustering is possible.
      continue;
    }

    TF_ASSIGN_OR_RETURN(bool contracted_edge,
                        TryToContractEdge(*cluster_from, to));

    if (!contracted_edge) {
      continue;
    }

    const Cluster& cluster_to = clusters_[to].Get();

    // Merge the clusters. ContractEdge uses 'from' as the number of the
    // merged node, so make sure 'from' is the chosen representative.
    cluster_from->devices.insert(cluster_to.devices.begin(),
                                 cluster_to.devices.end());
    if (!cluster_to.resource_op_device.empty()) {
      cluster_from->resource_op_device = cluster_to.resource_op_device;
    }

    cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr;
    clusters_[from].Merge(&clusters_[to]);

    return true;
  }

  return false;
}
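On the merge step above: after a successful ContractEdge, the pass folds cluster `to` into cluster `from` and keeps `from` as the representative. The toy union-find below shows only that merging idea; the real UnionFind<Cluster> carries a payload and a separate representative field, so treat this purely as an illustration of the set merge.

#include <iostream>
#include <numeric>
#include <vector>

// Minimal union-find: Merge(a, b) folds b's set into a's so that a's root
// survives the merge, which is the property the pass relies on when it keeps
// `from` after contracting an edge.
class ToyUnionFind {
 public:
  explicit ToyUnionFind(int n) : parent_(n) {
    std::iota(parent_.begin(), parent_.end(), 0);
  }

  int Find(int x) {
    while (parent_[x] != x) {
      parent_[x] = parent_[parent_[x]];  // path halving
      x = parent_[x];
    }
    return x;
  }

  // Attach b's root under a's root so a's representative survives the merge.
  void Merge(int a, int b) { parent_[Find(b)] = Find(a); }

 private:
  std::vector<int> parent_;
};

int main() {
  ToyUnionFind clusters(5);
  clusters.Merge(1, 3);  // cluster 1 absorbs cluster 3
  clusters.Merge(1, 4);
  std::cout << clusters.Find(3) << " " << clusters.Find(4) << "\n";  // 1 1
  std::cout << clusters.Find(2) << "\n";                             // 2
  return 0;
}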
Status MarkForCompilationPassImpl::Run() {
  static std::atomic<int64> cluster_sequence_num;

  // Make sure that kernels have been registered on the JIT device.
  XlaOpRegistry::RegisterCompilationKernels();

@@ -801,232 +1099,10 @@ Status MarkForCompilationPassImpl::Run() {
  // some one-time work.
  XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1);

  OrderedNodeSet compilation_candidates;
  absl::flat_hash_set<Node*> isolated_nodes;
  TF_ASSIGN_OR_RETURN(std::tie(compilation_candidates, isolated_nodes),
                      FindCompilationCandidates());

  if (compilation_candidates.empty()) {
    VLOG(2) << "No compilable candidates";
    return Status::OK();
  }

  GraphCycles cycles;
  TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
                      CreateCycleDetectionGraph(graph_, &cycles));
  if (!cycle_detection_graph_ok) {
    return Status::OK();
  }

  auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
    return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
  };

  TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
      graph_, flib_def_, ignore_resource_ops, &cycles));

  std::unique_ptr<DeadnessAnalysis> deadness_analysis;
  if (!debug_options_.ignore_deadness_checks) {
    XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
    TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis));
  }

  // Each compilation candidate belongs to a cluster. The cluster's
  // representative names the node in the 'cycles' graph that represents the
  // cluster.
  std::vector<UnionFind<Cluster>> clusters;
  std::deque<UnionFind<Cluster>*> worklist;
  TF_RETURN_IF_ERROR(BuildInitialClusterSet(
      compilation_candidates, deadness_analysis.get(), &clusters, &worklist));

  int64 iteration_count = 0;

  // Repeatedly contract edges between clusters that are on the same device,
  // provided the contraction would not create a cycle.
  //
  // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
  // example, from the Grappler fusion pass).
  while (!worklist.empty()) {
    Cluster* cluster_from = &worklist.front()->Get();
    int from = cluster_from->representative;
    worklist.pop_front();

    Node* node_from = graph_->FindNodeId(from);
    if (node_from->IsControlFlow()) {
      // Control flow nodes aren't compilation candidates and should never
      // appear.
      return errors::Internal(
          "Found control flow node in clustering worklist: ",
          node_from->type_string());
    }

    if (isolated_nodes.count(node_from)) {
      continue;
    }

    for (int to : cycles.Successors(from)) {
      iteration_count++;
      if (to >= graph_->num_node_ids()) {
        // Node is a fictitious node that is present only in the cycle detection
        // graph. No clustering is possible.
        continue;
      }

      const Cluster& cluster_to = clusters[to].Get();
      Node* node_to = graph_->FindNodeId(to);
      if (compilation_candidates.find(node_to) ==
          compilation_candidates.cend()) {
        continue;
      }

      DCHECK(cluster_from->deadness_predicate.has_value() ==
             cluster_to.deadness_predicate.has_value());
      if (cluster_from->deadness_predicate != cluster_to.deadness_predicate) {
        continue;
      }

      TF_ASSIGN_OR_RETURN(bool devices_compatible,
                          AreDevicesCompatible(*cluster_from, cluster_to));
      if (!devices_compatible) {
        continue;
      }

      if (isolated_nodes.count(node_to)) {
        continue;
      }

      if (HasMismatchingXlaScope(node_from, node_to)) {
        continue;
      }

      // Ops that consume shapes cannot be the root of a cluster. This is an
      // optimization.
      if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
        continue;
      }

      // Don't exceed the maximum cluster size.
      if (clusters[from].Size() + clusters[to].Size() >
          debug_options_.max_cluster_size) {
        continue;
      }

      TF_ASSIGN_OR_RETURN(
          bool will_introduce_cross_device_dependency,
          ClusteringWillIntroduceInterDeviceDependency(
              to, compilation_candidates, absl::MakeSpan(clusters), cycles));

      if (will_introduce_cross_device_dependency) {
        continue;
      }

      // If contracting the edge would create a cycle, bail out. However, just
      // because we can't merge the clusters now does not mean we won't be able
      // to merge them in the future. e.g., if we have edges 1->2, 2->3 and
      // 1->3, we cannot contract edge 1->3. But if we first contract 1->2 then
      // we can later contract 1->3.
      if (!cycles.ContractEdge(from, to)) {
        continue;
      }

      // Merge the clusters. ContractEdge uses 'from' as the number of the
      // merged node, so make sure 'from' is the chosen representative.
      cluster_from->devices.insert(cluster_to.devices.begin(),
                                   cluster_to.devices.end());
      if (!cluster_to.resource_op_device.empty()) {
        cluster_from->resource_op_device = cluster_to.resource_op_device;
      }
      cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr;
      clusters[from].Merge(&clusters[to]);

      worklist.push_back(&clusters[from]);
      break;
    }
  }

  VLOG(1) << iteration_count << " iterations in inner loop for graph with "
          << compilation_candidates.size()
          << " compilation candidates. Iterations per compilation candidate: "
          << ((1.0 * iteration_count) / compilation_candidates.size());

  // Count the number of non-trivial elements in each cluster.
  std::vector<int> effective_cluster_sizes(graph_->num_node_ids());

  // has_functional_control_flow remembers if a cluster contains a functional
  // control flow node.
  std::vector<bool> has_functional_control_flow(graph_->num_node_ids());

  for (const Node* n : compilation_candidates) {
    int cluster = clusters[n->id()].Get().representative;
    // We want clusters to be big enough that the benefit from XLA's
    // optimizations offsets XLA related overhead (for instance we add some
    // Switch/Merge nodes into the graph to implement lazy compilation). To
    // this end, we don't count Identity and Constant nodes because they do not
    // enable interesting optimizations by themselves.
    if (!n->IsIdentity() && !n->IsConstant()) {
      effective_cluster_sizes[cluster]++;
    }
    if (n->type_string() == "While" || n->type_string() == "If") {
      has_functional_control_flow[cluster] = true;
    }
  }

  // Names for each cluster.
  std::unordered_map<int, string> cluster_names;

  if (debug_options_.dump_graphs) {
    DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
  }

  // Mark clusters for compilation that:
  // * are placed on a device that requires compilation (an XlaDevice),
  // * are explicitly marked for compilation (_XlaCompile=true), or
  // * have more than debug_options_.xla_min_cluster_size elements (applicable
  //   only if compilation is enabled, otherwise there will be no such
  //   candidates).
  for (Node* n : compilation_candidates) {
    const Cluster& cluster = clusters[n->id()].Get();
    TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
                        ShouldCompileCluster(cluster));
    if (!should_compile_cluster) {
      continue;
    }

    int cluster_repr = cluster.representative;

    // Compile if the user marked this node _XlaCompile=true
    bool compile_attr = false;
    bool marked_for_compilation = false;
    if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
      marked_for_compilation = compile_attr;
    } else if (flib_def_->GetAttr(*n, kXlaCompileAttr, &compile_attr).ok()) {
      marked_for_compilation = compile_attr;
    }

    // We assume that functional If and While nodes have at least
    // min_cluster_size non-trivial nodes in them. It would be more principled
    // to (recursively) verify this fact, but that's probably not worth the
    // trouble.

    if (effective_cluster_sizes[cluster_repr] >=
            debug_options_.min_cluster_size ||
        has_functional_control_flow[cluster_repr] || marked_for_compilation) {
      string& name = cluster_names[cluster_repr];

      if (name.empty()) {
        name = absl::StrCat("cluster_", cluster_sequence_num++);
      }
      n->AddAttr(kXlaClusterAttr, name);
      n->AddAttr(kXlaAlreadyClustered, true);
      VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
    }
  }

  if (debug_options_.dump_graphs) {
    DumpPostClusteringGraphs();
  }

  VLogClusteringSummary();
  TF_RETURN_IF_ERROR(Initialize());
  TF_RETURN_IF_ERROR(RunEdgeContractionLoop());
  TF_RETURN_IF_ERROR(CreateClusters());
  TF_RETURN_IF_ERROR(DumpDebugInfo());

  return Status::OK();
}
@@ -519,7 +519,7 @@ Status XlaDevice::RefreshStatus() {
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  // Any op assigned to the device that isn't rewritten by the graph rewriter
  // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes
  // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
  // it just-in-time.
  OpKernel* (*factory)(OpKernelConstruction*) =
      [](OpKernelConstruction* context) -> OpKernel* {
@@ -247,6 +247,9 @@ class XlaAssignVariableOp : public OpKernel {
                          data::MakeIteratorOp); \
  REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
                          data::AnonymousIteratorHandleOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \
      data::AnonymousIteratorHandleOp); \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
                          data::IteratorGetNextOp); \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
@ -1,352 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <deque>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Is 'node' an operator that consumes only the shape of its input, not the
|
||||
// data itself?
|
||||
static bool IsShapeConsumerOp(const Node& node) {
|
||||
return node.type_string() == "Shape" || node.type_string() == "ShapeN" ||
|
||||
node.type_string() == "Rank" || node.type_string() == "Size";
|
||||
}
|
||||
|
||||
// Returns true if the op can be decomposed into XLA ops for which
|
||||
// there are fusible elemental implementations.
|
||||
static bool IsXlaFusible(const NodeDef& node) {
|
||||
static const std::unordered_set<std::string>* elementwise_ops =
|
||||
new std::unordered_set<std::string>(
|
||||
{// tf2xla/kernels/aggregate_ops.cc
|
||||
"AddN",
|
||||
// tf2xla/kernels/binary_ops.cc
|
||||
"Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
|
||||
"FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
|
||||
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
|
||||
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
|
||||
"TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
|
||||
"GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
|
||||
"SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
|
||||
// tf2xla/kernels/unary_ops.cc
|
||||
"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
|
||||
"Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
|
||||
"Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
|
||||
"Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
|
||||
"Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
|
||||
"Square", "Tan", "Tanh", "Real", "Imag",
|
||||
// tf2xla/kernels/bcast_ops.cc
|
||||
"BroadcastArgs", "BroadcastGradientArgs",
|
||||
// tf2xla/kernels/bias_ops.cc
|
||||
"BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
|
||||
// tf2xla/kernels/cast_op.cc
|
||||
"Cast",
|
||||
// tf2xla/kernels/concat_op.cc
|
||||
"Concat", "ConcatV2", "ConcatOffset",
|
||||
// tf2xla/kernels/const_op.cc
|
||||
"Const",
|
||||
// tf2xla/kernels/elu_op.cc
|
||||
"Elu", "EluGrad", "Selu", "SeluGrad",
|
||||
// tf2xla/kernels/fill_op.cc
|
||||
"Fill",
|
||||
// tf2xla/kernels/identity_op.cc
|
||||
"Identity", "IdentityN", "PreventGradient",
|
||||
"StopGradient", /*"Snapshot",*/
|
||||
// tf2xla/kernels/index_ops.cc
|
||||
"ArgMax", "ArgMin",
|
||||
// tf2xla/kernels/mirror_pad_op.cc
|
||||
"MirrorPad",
|
||||
// tf2xla/kernels/one_hot_op.cc
|
||||
"OneHot",
|
||||
// tf2xla/kernels/pack_op.cc
|
||||
"Pack",
|
||||
// tf2xla/kernels/pad_op.cc
|
||||
"Pad", "PadV2",
|
||||
// tf2xla/kernels/relu_op.cc
|
||||
"Relu", "Relu6", "ReluGrad", "Relu6Grad",
|
||||
// tf2xla/kernels/reshape_op.cc
|
||||
"Reshape",
|
||||
// tf2xla/kernels/reverse_op.cc
|
||||
"Reverse", "ReverseV2",
|
||||
// tf2xla/kernels/reverse_sequence_op.cc
|
||||
"ReverseSequence",
|
||||
// tf2xla/kernels/shape_op.cc
|
||||
"Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
|
||||
"ZerosLike", "OnesLike",
|
||||
// tf2xla/kernels/slice_op.cc
|
||||
"Slice",
|
||||
// tf2xla/kernels/split_op.cc
|
||||
"Split", "SplitV",
|
||||
// tf2xla/kernels/strided_slice_op.cc
|
||||
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
|
||||
// tf2xla/kernels/tile_ops.cc
|
||||
"Tile",
|
||||
// tf2xla/kernels/transpose_op.cc
|
||||
"Transpose", "InvertPermutation",
|
||||
// tf2xla/kernels/unpack_op.cc
|
||||
"Unpack"});
|
||||
|
||||
return elementwise_ops->count(node.op()) > 0;
|
||||
}
|
||||
|
||||
Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
|
||||
const grappler::GrapplerItem& item,
|
||||
GraphDef* output) {
|
||||
VLOG(2) << "Here at fusion optimizer";
|
||||
|
||||
// TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op.
|
||||
// Once that happens, the expected interaction between this optimizer and when
|
||||
// the global_jit_level is set is as follows: Fusion optimizer will replace
|
||||
// appropriate fusion clusters with XlaLaunch nodes. The remaining graph can
|
||||
// be further compiled where possible via mark_for_compilation_pass. Note that
|
||||
// this might lead to inefficient clustering, and it is best to use either the
|
||||
// fusion optimizer or the global_jit flag, and not combine the two.
|
||||
|
||||
// Create a Graph out of GraphDef. This is required currently because the
|
||||
// helpers around clustering, encapsulation etc work on graphs.
|
||||
FunctionLibraryDefinition function_library(OpRegistry::Global(),
|
||||
item.graph.library());
|
||||
Graph graph(function_library);
|
||||
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
|
||||
shape_refiner.set_require_shape_inference_fns(false);
|
||||
shape_refiner.set_disable_constant_propagation(true);
|
||||
ImportGraphDefOptions options;
|
||||
// Graph optimization happens at the late stage of graph execution, when
|
||||
// colocation constraints are already validated previously and the device
|
||||
// placement of nodes has also completed, so there is no need to validate
|
||||
// colocation constraints again.
|
||||
options.validate_colocation_constraints = false;
|
||||
options.validate_shape = false;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
|
||||
|
||||
// Collect nodes that can be fused via XLA, while ignoring those that
|
||||
// explicitly ask for XLA: (*) nodes that are marked to be compiled
|
||||
// explicitly. (*) nodes assigned to XLA device.
|
||||
OrderedNodeSet compilation_candidates;
|
||||
for (Node* node : graph.op_nodes()) {
|
||||
// If there is a _XlaCompile annotation, ignore the node if it is
|
||||
// true. Nodes are marked with this attr via experimental_jit_scope, and
|
||||
// will be handled by the mark_for_compilation pass.
|
||||
bool compile = false;
|
||||
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
|
||||
if (status.ok() && compile) {
|
||||
continue;
|
||||
}
|
||||
// If there is already a _XlaCluster annotation, ignore the node. Nodes are
|
||||
// marked with this attr to indicate they are already part of a cluster and
|
||||
// hence ignored.
|
||||
status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile);
|
||||
if (status.ok()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If there is an explicit XLA device placement, ignore the node.
|
||||
DeviceType device_type("");
|
||||
TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
|
||||
if (device_type.type_string().find("XLA") != string::npos) continue;
|
||||
|
||||
// Assume all fusible ops are registered.
|
||||
// TODO(hpucha): Check for registration if possible.
|
||||
if (!IsXlaFusible(node->def())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// XLA does not offer guaranteed aliasing between the input and output of
|
||||
// the XLA cluster so it can't implement the forward-tensor-ref semantic.
|
||||
// Leave such nodes out of XLA clusters.
|
||||
if (HasForwardedRefInput(*node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
compilation_candidates.insert(node);
|
||||
}
|
||||
|
||||
if (compilation_candidates.empty()) {
|
||||
VLOG(2) << "No compilable candidates";
|
||||
*output = item.graph;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
|
||||
CreateCycleDetectionGraph(&graph, &cycles));
|
||||
if (!cycle_detection_graph_ok) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
|
||||
&graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));
|
||||
|
||||
// TODO(hpucha): Make clustering more robust. There are two known issues that
|
||||
// we need to mitigate: (a) Non-resource variables can cause deadlocks
|
||||
// when clustering changes order of execution. See b/77263461 for a specific
|
||||
// example. (b) Queue operations can also cause deadlocks. See b/77261498 for
|
||||
// example.
|
||||
|
||||
struct Cluster {
|
||||
// Identifies the node that represents this cluster in the cycle detection
|
||||
// graph.
|
||||
int representative = -1;
|
||||
};
|
||||
|
||||
// Each compilation candidate belongs to a cluster. The cluster's
|
||||
// representative names the node in the 'cycles' graph that represents the
|
||||
// cluster.
|
||||
std::vector<UnionFind<Cluster>> clusters(graph.num_node_ids());
|
||||
std::deque<UnionFind<Cluster>*> worklist;
|
||||
for (Node* node : compilation_candidates) {
|
||||
Cluster& cluster = clusters[node->id()].Get();
|
||||
cluster.representative = node->id();
|
||||
worklist.push_back(&clusters[node->id()]);
|
||||
}
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> deadness_analysis;
|
||||
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness_analysis));
|
||||
|
||||
// Repeatedly contract edges between clusters that are on the same device,
|
||||
// provided the contraction would not create a cycle. This is a simplified
|
||||
// version of the clustering in mark_for_compilation_pass that also deals with
|
||||
// nodes that are explicitly tagged to be compiled/clustered.
|
||||
while (!worklist.empty()) {
|
||||
int from = worklist.front()->Get().representative;
|
||||
worklist.pop_front();
|
||||
|
||||
Node* node_from = graph.FindNodeId(from);
|
||||
if (node_from->IsControlFlow()) {
|
||||
// Control flow nodes aren't compilation candidates and should never
|
||||
// appear.
|
||||
return errors::Internal(
|
||||
"Found control flow node in clustering worklist: ",
|
||||
node_from->type_string());
|
||||
}
|
||||
for (int to : cycles.Successors(from)) {
|
||||
if (to >= graph.num_node_ids()) {
|
||||
// Node is a "frame" node that is present only in the cycle detection
|
||||
// graph. No clustering is possible.
|
||||
continue;
|
||||
}
|
||||
Node* node_to = graph.FindNodeId(to);
|
||||
if (compilation_candidates.find(node_to) ==
|
||||
compilation_candidates.cend()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do not cluster across devices.
|
||||
if (node_from->def().device() != node_to->def().device()) {
|
||||
VLOG(2) << "Devices " << node_from->def().device() << " "
|
||||
<< node_to->def().device();
|
||||
VLOG(2) << "Device names " << node_from->assigned_device_name() << " "
|
||||
<< node_to->assigned_device_name();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Ops that consume shapes cannot be the root of a cluster. This is an
|
||||
// optimization.
|
||||
if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
DeadnessAnalysis::DeadnessPredicate pred_from,
|
||||
deadness_analysis->GetPredicateFor(node_from, Graph::kControlSlot));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
DeadnessAnalysis::DeadnessPredicate pred_to,
|
||||
deadness_analysis->GetPredicateFor(node_to, Graph::kControlSlot));
|
||||
|
||||
if (pred_from != pred_to) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If contracting the edge would create a cycle, bail out.
|
||||
// However, just because we can't merge the clusters now does not mean
|
||||
// we won't be able to merge them in the future.
|
||||
// e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
|
||||
// 1->3. But if we first contract 1->2 then we can later contract 1->3.
|
||||
if (!cycles.ContractEdge(from, to)) continue;
|
||||
|
||||
// Merge the clusters. ContractEdge uses 'from' as the number of the
|
||||
// merged node, so make sure 'from' is the chosen representative.
|
||||
clusters[from].Merge(&clusters[to]);
|
||||
|
||||
worklist.push_back(&clusters[from]);
|
||||
break;
|
||||
}
|
||||
}
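As a reading aid only (not part of this change), the union-find contraction performed by the loop above can be sketched in a few lines of Python. UnionFind, contract_edges and would_create_cycle are made-up names; would_create_cycle stands in for GraphCycles::ContractEdge refusing a merge, and the sketch keeps 'from' as the representative the same way clusters[from].Merge(&clusters[to]) does.

class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        # Path-halving find: walk to the root, shortening the chain as we go.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def merge(self, a, b):
        # Keep 'a' as the representative of the merged cluster.
        self.parent[self.find(b)] = self.find(a)

def contract_edges(num_nodes, edges, would_create_cycle):
    # Greedily merge edge endpoints unless the merge would introduce a cycle.
    uf = UnionFind(num_nodes)
    for u, v in edges:
        if uf.find(u) == uf.find(v):
            continue  # already in the same cluster
        if would_create_cycle(u, v):
            continue  # skip for now; a later contraction may make it legal
        uf.merge(u, v)
    return [uf.find(i) for i in range(num_nodes)]

# With a permissive cycle check, a chain of edges collapses into one cluster.
assert contract_edges(3, [(0, 1), (1, 2)], lambda u, v: False) == [0, 0, 0]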
|
||||
|
||||
// Count the number of non-trivial elements in each cluster.
|
||||
std::vector<int> effective_cluster_sizes(graph.num_node_ids());
|
||||
for (const Node* n : compilation_candidates) {
|
||||
int cluster = clusters[n->id()].Get().representative;
|
||||
// Identity nodes will be removed if the node gets marked for compilation.
|
||||
// Therefore we don't want to count them towards the effective cluster size.
|
||||
if (n->def().op() != "Identity") {
|
||||
effective_cluster_sizes[cluster]++;
|
||||
}
|
||||
}
|
||||
|
||||
const int min_cluster_size = 2;
|
||||
int num_clusters = 0;
|
||||
for (auto size : effective_cluster_sizes) {
|
||||
if (size >= min_cluster_size) {
|
||||
VLOG(3) << "Cluster " << num_clusters << " " << size;
|
||||
num_clusters++;
|
||||
}
|
||||
}
|
||||
|
||||
// Names for each cluster.
|
||||
std::unordered_map<int, string> cluster_names;
|
||||
// Sequence number generator to ensure clusters have unique names.
|
||||
static std::atomic<int64> cluster_sequence_num;
|
||||
|
||||
for (Node* n : compilation_candidates) {
|
||||
int cluster = clusters[n->id()].Get().representative;
|
||||
|
||||
// Compile if this is a cluster of >= min_cluster_size compilable operators.
|
||||
if (effective_cluster_sizes[cluster] >= min_cluster_size) {
|
||||
string& name = cluster_names[cluster];
|
||||
|
||||
if (name.empty()) {
|
||||
name = absl::StrCat("cluster_", cluster_sequence_num++);
|
||||
}
|
||||
n->AddAttr(kXlaClusterAttr, name);
|
||||
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
|
||||
}
|
||||
}
|
||||
|
||||
graph.ToGraphDef(output);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion");
|
||||
|
||||
} // namespace tensorflow
|
@ -1,49 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Optimizes graphs by fusing ops where possible, resulting in more efficient
|
||||
// execution.
|
||||
class XlaFusionOptimizer : public grappler::CustomGraphOptimizer {
|
||||
public:
|
||||
XlaFusionOptimizer() {}
|
||||
~XlaFusionOptimizer() override {}
|
||||
|
||||
Status Init(
|
||||
const RewriterConfig_CustomGraphOptimizer* config = nullptr) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string name() const override { return "xla-fusion"; };
|
||||
|
||||
Status Optimize(grappler::Cluster* cluster,
|
||||
const grappler::GrapplerItem& item,
|
||||
GraphDef* output) override;
|
||||
|
||||
void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) override {
|
||||
// Nothing to do for XlaFusionOptimizer.
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
|
@ -1,208 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
|
||||
#include "tensorflow/cc/ops/resource_variable_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/grappler/utils/grappler_test.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
REGISTER_OP("UncompilableNullary").Output("o: float");
|
||||
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
|
||||
|
||||
class XlaFusionOptimizerTest : public grappler::GrapplerTest {
|
||||
protected:
|
||||
std::unordered_map<string, string> GetClusters(const GraphDef& graph) {
|
||||
std::unordered_map<string, string> ids;
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
string cluster;
|
||||
if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) {
|
||||
CHECK(!cluster.empty());
|
||||
ids[node.name()] = cluster;
|
||||
}
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, Chains) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a =
|
||||
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
|
||||
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
|
||||
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
|
||||
Node* d =
|
||||
ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
|
||||
Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
|
||||
ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_EQ(4, clusters.size());
|
||||
EXPECT_EQ(clusters["B"], clusters["C"]);
|
||||
EXPECT_EQ(clusters["E"], clusters["F"]);
|
||||
EXPECT_NE(clusters["B"], clusters["E"]);
|
||||
EXPECT_TRUE(clusters.find("A") == clusters.cend());
|
||||
EXPECT_TRUE(clusters.find("D") == clusters.cend());
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, FusibleOps) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp(
|
||||
"Placeholder",
|
||||
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
|
||||
Node* b = ops::SourceOp(
|
||||
"Placeholder",
|
||||
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
|
||||
|
||||
Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C"));
|
||||
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
|
||||
ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
|
||||
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_EQ(2, clusters.size());
|
||||
EXPECT_EQ(clusters["C"], clusters["E"]);
|
||||
EXPECT_TRUE(clusters.find("D") == clusters.cend());
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp(
|
||||
"Placeholder",
|
||||
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
|
||||
Node* b = ops::SourceOp(
|
||||
"Placeholder",
|
||||
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
|
||||
|
||||
Node* c = ops::BinaryOp(
|
||||
"Add", a, b,
|
||||
builder.opts().WithName("C").WithDevice("/device:XLA_CPU"));
|
||||
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
|
||||
Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
|
||||
ops::UnaryOp("Cos", e,
|
||||
builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true));
|
||||
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, UncompilableCycles) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b =
|
||||
ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
|
||||
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
|
||||
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
|
||||
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_EQ(3, clusters.size());
|
||||
EXPECT_EQ(clusters["A"], clusters["B"]);
|
||||
EXPECT_EQ(clusters["A"], clusters["C"]);
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output var_handle =
|
||||
ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({}));
|
||||
Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f);
|
||||
Output begin = ops::Const(root.WithOpName("begin"), 0);
|
||||
Output end = ops::Const(root.WithOpName("end"), 1);
|
||||
Output strides = ops::Const(root.WithOpName("strides"), 1);
|
||||
ops::ResourceStridedSliceAssign assign_1(
|
||||
root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign);
|
||||
ops::ResourceStridedSliceAssign assign_2(
|
||||
root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign);
|
||||
root.graph()->AddControlEdge(assign_1.operation.node(),
|
||||
assign_2.operation.node());
|
||||
grappler::GrapplerItem item;
|
||||
root.graph()->ToGraphDef(&item.graph);
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_NE(clusters["assign_1"], clusters["assign_2"]);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -23,6 +23,7 @@ import itertools
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -1041,6 +1042,62 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([2], dtype=np.int64),
|
||||
expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
|
||||
|
||||
def testBatchMatMulBroadcast(self):
|
||||
"""Tests broadcasting behavior of BatchMatMul."""
|
||||
with compat.forward_compatibility_horizon(2019, 4, 19):
|
||||
# [2, 3] @ [1, 3, 4] -> [1, 2, 4]
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([[10, 20, 30], [11, 21, 31]], dtype=np.float32),
|
||||
np.array([[[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]],
|
||||
dtype=np.float32),
|
||||
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
|
||||
dtype=np.float32))
|
||||
# [1, 2, 3] @ [3, 4] -> [1, 2, 4]
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([[[10, 20, 30], [11, 21, 31]]], dtype=np.float32),
|
||||
np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]],
|
||||
dtype=np.float32),
|
||||
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
|
||||
dtype=np.float32))
|
||||
# [2, 1, 3] @ [3, 1] -> [2, 1, 1]
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
|
||||
np.array([[1], [2], [3]], dtype=np.float32),
|
||||
expected=np.array([[[140]], [[146]]], dtype=np.float32))
|
||||
# [2, 1, 3] @ [1, 3] -> [2, 1, 1] (adjoint_b)
|
||||
self._testBinary(
|
||||
lambda x, y: math_ops.matmul(x, y, adjoint_b=True),
|
||||
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
|
||||
np.array([[1, 2, 3]], dtype=np.float32),
|
||||
expected=np.array([[[140]], [[146]]], dtype=np.float32))
|
||||
# [2, 3, 1] @ [3, 1] -> [2, 1, 1] (adjoint_a)
|
||||
self._testBinary(
|
||||
lambda x, y: math_ops.matmul(x, y, adjoint_a=True),
|
||||
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
|
||||
np.array([[1], [2], [3]], dtype=np.float32),
|
||||
expected=np.array([[[140]], [[146]]], dtype=np.float32))
|
||||
# [2, 3, 1] @ [1, 3] -> [2, 1, 1] (adjoint_a and adjoint_b)
|
||||
self._testBinary(
|
||||
lambda x, y: math_ops.matmul(x, y, adjoint_a=True, adjoint_b=True),
|
||||
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
|
||||
np.array([[1, 2, 3]], dtype=np.float32),
|
||||
expected=np.array([[[140]], [[146]]], dtype=np.float32))
|
||||
# [5, 1, 2, 3] @ [1, 7, 3, 4] -> [5, 7, 2, 4]
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.ones([5, 1, 2, 3], dtype=np.float32),
|
||||
np.ones([1, 7, 3, 4], dtype=np.float32),
|
||||
expected=np.full([5, 7, 2, 4], 3, dtype=np.float32))
|
||||
# [4, 5, 1, 2, 3] @ [1, 1, 3, 5] -> [4, 5, 1, 2, 5]
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.full([4, 5, 1, 2, 3], 2., dtype=np.float32),
|
||||
np.full([1, 1, 3, 5], 3., dtype=np.float32),
|
||||
expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
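The cases above follow ordinary NumPy-style batch broadcasting: the leading (batch) dimensions broadcast elementwise while the trailing two dimensions obey matrix-multiplication rules. A standalone NumPy sketch (illustrative only, not part of the test) reproducing the first case, [2, 3] @ [1, 3, 4] -> [1, 2, 4]:

import numpy as np

x = np.array([[10, 20, 30], [11, 21, 31]], dtype=np.float32)              # [2, 3]
y = np.array([[[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]], np.float32)   # [1, 3, 4]
out = np.matmul(x, y)   # the 2-D lhs is broadcast against the batch dim of y
assert out.shape == (1, 2, 4)
assert out[0, 0, 0] == 140 and out[0, 1, 3] == 584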
|
||||
|
||||
def testPad(self):
|
||||
for dtype, pad_type in itertools.product(
|
||||
self.numeric_types, [np.int32, np.int64]):
|
||||
|
@ -387,11 +387,18 @@ class TensorArrayTest(xla_test.XLATestCase):
|
||||
def fn():
|
||||
ta = tensor_array_ops.TensorArray(
|
||||
dtype=dtypes.float32, tensor_array_name="foo", size=3)
|
||||
return ta.write(-1, np.int32(7)).flow
|
||||
return ta.write(-1, constant_op.constant(7)).flow
|
||||
|
||||
# Test writing the wrong datatype.
|
||||
with self.assertRaisesOpError(
|
||||
"TensorArray dtype is float but op has dtype int32"):
|
||||
# TODO(b/129870929): Remove InvalidArgumentError/second regexp after all
|
||||
# callers provide proper init dtype.
|
||||
with self.assertRaisesRegexp(
|
||||
(ValueError, errors.InvalidArgumentError),
|
||||
r"("
|
||||
r"conversion requested dtype float32 for Tensor with dtype int32"
|
||||
r"|"
|
||||
r"TensorArray dtype is float but op has dtype int32"
|
||||
r")"):
|
||||
xla.compile(fn)[0].eval()
|
||||
|
||||
@test_util.disable_control_flow_v2("b/124334096 verify dtype")
|
||||
|
@ -173,6 +173,7 @@ tf_cuda_library(
|
||||
name = "trt_resources",
|
||||
srcs = [
|
||||
"utils/trt_int8_calibrator.cc",
|
||||
"utils/trt_lru_cache.cc",
|
||||
"utils/trt_resources.cc",
|
||||
],
|
||||
hdrs = [
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
@ -1350,11 +1351,19 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input,
|
||||
// the dims are unknown or need to be inferred. And we don't do further checks
|
||||
// but rely on the caller to not make mistakes.
|
||||
// Otherwise we do a simple check to make sure the total sizes are the same.
|
||||
if (AreDimsStaticWithDifferentSize(input_dims, dims, input.is_tensor())) {
|
||||
// If an input is a weight, it is going to become a tensor via
|
||||
// CreateConstantLayer. So we can treat it as a tensor for
|
||||
// AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors.
|
||||
if (AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
|
||||
return errors::InvalidArgument(
|
||||
"Incompatible shapes: ", DebugString(input_dims), " vs. ",
|
||||
DebugString(dims));
|
||||
}
|
||||
// ConstantLayer requires static shapes (cannot infer -1).
|
||||
if (input.is_weights() && !HasStaticShape(dims)) {
|
||||
return errors::InvalidArgument("Shape is not fully defined: ",
|
||||
DebugString(dims));
|
||||
}
|
||||
if (validation_only) {
|
||||
*tensor = nullptr;
|
||||
return Status::OK();
|
||||
@ -1614,7 +1623,7 @@ struct LambdaFactory {
|
||||
switch (op) {
|
||||
case OP_CATEGORY::RSQRT: {
|
||||
VLOG(2) << "RSQRT GETS DONE";
|
||||
return [](T t) -> T { return 1.0 / sqrt(t); };
|
||||
return [](T t) -> T { return 1.0 / std::sqrt(t); };
|
||||
}
|
||||
case OP_CATEGORY::NEG:
|
||||
return [](T t) -> T { return -t; };
|
||||
@ -1633,7 +1642,7 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
|
||||
case OP_CATEGORY::RSQRT: {
|
||||
VLOG(2) << "RSQRT GETS DONE";
|
||||
return [](Eigen::half t) {
|
||||
return Eigen::half(1.0 / sqrt(static_cast<float>(t)));
|
||||
return Eigen::half(1.0 / std::sqrt(static_cast<float>(t)));
|
||||
};
|
||||
}
|
||||
case OP_CATEGORY::NEG:
|
||||
@ -4237,6 +4246,95 @@ Status ConvertTopK(OpConverterParams* params) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertDepthSpaceShuffle(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
TFAttrs attrs(node_def);
|
||||
const int block_size = attrs.get<int64>("block_size");
|
||||
if (block_size < 2) {
|
||||
return errors::InvalidArgument("Block size must be 2 or greater, at ",
|
||||
node_def.name());
|
||||
}
|
||||
const string data_format = attrs.get<string>("data_format");
|
||||
if (data_format != "NCHW" && data_format != "NHWC") {
|
||||
return errors::Unimplemented("Data format ", data_format,
|
||||
" is not supported, at ", node_def.name());
|
||||
}
|
||||
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
|
||||
if (dims.nbDims != 3) {
|
||||
return errors::InvalidArgument("The input to ", node_def.op(),
|
||||
" must be rank 4, at ", node_def.name());
|
||||
}
|
||||
const int num_channels = data_format == "NCHW" ? dims.d[0] : dims.d[2];
|
||||
const int h = data_format == "NCHW" ? dims.d[1] : dims.d[0];
|
||||
const int w = data_format == "NCHW" ? dims.d[2] : dims.d[1];
|
||||
// Get shuffle parameters.
|
||||
nvinfer1::Dims first_shuffle_shape;
|
||||
nvinfer1::Permutation transpose_perm;
|
||||
nvinfer1::Dims second_shuffle_shape;
|
||||
if (node_def.op() == "DepthToSpace") {
|
||||
if (num_channels % (block_size * block_size) != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Number of channels must be divisible by block_size*block_size, at ",
|
||||
node_def.name());
|
||||
}
|
||||
// First Reshape [C, H, W] -> [r, r, C/(r*r), H, W]
|
||||
first_shuffle_shape = {
|
||||
/*nbDims=*/5,
|
||||
/*d=*/{block_size, block_size, num_channels / (block_size * block_size),
|
||||
h, w}};
|
||||
// Transpose [r, r, C/(r*r), H, W] -> [C/(r*r), H, r, W, r]
|
||||
transpose_perm = {2, 3, 0, 4, 1};
|
||||
// Second Reshape [C/(r*r), H, r, W, r] -> [C/(r*r), H * r, W * r]
|
||||
second_shuffle_shape =
|
||||
nvinfer1::DimsCHW(num_channels / (block_size * block_size),
|
||||
h * block_size, w * block_size);
|
||||
} else if (node_def.op() == "SpaceToDepth") {
|
||||
if (h % block_size != 0 || w % block_size != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Width and height must be divisible by block_size, at ",
|
||||
node_def.name());
|
||||
}
|
||||
// First Reshape [C, H, W] -> [C, H/r, r, W/r, r]
|
||||
first_shuffle_shape = {/*nbDims=*/5,
|
||||
/*d=*/{num_channels, h / block_size, block_size,
|
||||
w / block_size, block_size}};
|
||||
// Transpose [C, H/r, r, W/r, r] -> [r, r, C, H/r, W/r]
|
||||
transpose_perm = {2, 4, 0, 1, 3};
|
||||
// Second Reshape [r, r, C, H/r, W/r] -> [C*r*r, H/r, W/r]
|
||||
second_shuffle_shape = nvinfer1::DimsCHW(
|
||||
num_channels * block_size * block_size, h / block_size, w / block_size);
|
||||
}
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
nvinfer1::IShuffleLayer* first_shuffle =
|
||||
params->converter->network()->addShuffle(*inputs.at(0).tensor());
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name());
|
||||
if (data_format == "NHWC") {
|
||||
first_shuffle->setFirstTranspose({2, 0, 1});
|
||||
}
|
||||
first_shuffle->setReshapeDimensions(first_shuffle_shape);
|
||||
first_shuffle->setSecondTranspose(transpose_perm);
|
||||
|
||||
nvinfer1::IShuffleLayer* second_shuffle =
|
||||
params->converter->network()->addShuffle(*first_shuffle->getOutput(0));
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name());
|
||||
second_shuffle->setReshapeDimensions(second_shuffle_shape);
|
||||
if (data_format == "NHWC") {
|
||||
second_shuffle->setSecondTranspose({1, 2, 0});
|
||||
}
|
||||
|
||||
params->converter->MarkQuantizationRangesAsInferrable(
|
||||
inputs.at(0).tensor(), first_shuffle->getOutput(0));
|
||||
params->converter->MarkQuantizationRangesAsInferrable(
|
||||
first_shuffle->getOutput(0), second_shuffle->getOutput(0));
|
||||
params->outputs->push_back(TRT_TensorOrWeights(second_shuffle->getOutput(0)));
|
||||
return Status::OK();
|
||||
}
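As a sanity check on the reshape/transpose/reshape decomposition spelled out in the comments above, the following standalone NumPy sketch (illustrative only; depth_to_space_nchw is a made-up helper working on [C, H, W] data without the batch dimension, matching the TRT dims used here) reproduces the expected output of the {4, 2, 2} NCHW DepthToSpace test case further below:

import numpy as np

def depth_to_space_nchw(x, r):
    # DepthToSpace on a [C, H, W] array via reshape -> transpose -> reshape.
    c, h, w = x.shape
    assert c % (r * r) == 0
    y = x.reshape(r, r, c // (r * r), h, w)        # [C,H,W] -> [r,r,C/(r*r),H,W]
    y = y.transpose(2, 3, 0, 4, 1)                 # -> [C/(r*r),H,r,W,r]
    return y.reshape(c // (r * r), h * r, w * r)   # -> [C/(r*r),H*r,W*r]

x = np.arange(16).reshape(4, 2, 2)  # same values and shape as the {4, 2, 2} case
assert depth_to_space_nchw(x, 2).ravel().tolist() == [
    0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15]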
|
||||
|
||||
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
|
||||
Status ConvertCombinedNMS(OpConverterParams* params) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -4416,6 +4514,7 @@ static void RegisterValidatableOpConverters(
|
||||
(*registration)["Const"] = ConvertConst;
|
||||
(*registration)["Conv2D"] = ConvertConv2D;
|
||||
(*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
|
||||
(*registration)["DepthToSpace"] = ConvertDepthSpaceShuffle;
|
||||
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
|
||||
(*registration)["ExpandDims"] = ConvertExpandDims;
|
||||
(*registration)["GatherV2"] = ConvertGather;
|
||||
@ -4430,6 +4529,7 @@ static void RegisterValidatableOpConverters(
|
||||
(*registration)["Slice"] = ConvertSlice;
|
||||
(*registration)["Snapshot"] = ConvertIdentity; // Snapshot should be removed
|
||||
(*registration)["Softmax"] = ConvertSoftmax;
|
||||
(*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle;
|
||||
(*registration)["Split"] = ConvertSplit;
|
||||
(*registration)["Square"] = ConvertSquare;
|
||||
(*registration)["Squeeze"] = ConvertSqueeze;
|
||||
|
@ -212,6 +212,19 @@ std::vector<CType> InitTestVector(int size, CType start_value = CType(0)) {
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename InCType, typename OutCType>
|
||||
struct StaticCaster {
|
||||
OutCType operator()(InCType in) const { return static_cast<OutCType>(in); }
|
||||
};
|
||||
|
||||
template <typename InCType, typename OutCType>
|
||||
std::vector<OutCType> CastTestVector(const std::vector<InCType>& vals) {
|
||||
std::vector<OutCType> res(vals.size());
|
||||
std::transform(vals.begin(), vals.end(), res.begin(),
|
||||
StaticCaster<InCType, OutCType>());
|
||||
return res;
|
||||
}
|
||||
|
||||
// Fake ITensor implementation for testing purposes.
|
||||
class FakeITensor : public nvinfer1::ITensor {
|
||||
public:
|
||||
@ -721,19 +734,25 @@ TEST_F(ConverterTest, TransposeTensor) {
|
||||
ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions());
|
||||
}
|
||||
|
||||
void TestPrepareTensorForShape_Tensor(
|
||||
const std::vector<int>& tensor_dims, const std::vector<int>& reshape_dims,
|
||||
const std::vector<int>& expected_tensor_dims, Converter* converter,
|
||||
void TestPrepareTensorForShape(
|
||||
const std::vector<int>& input_dims, const std::vector<int>& reshape_dims,
|
||||
const std::vector<int>& expected_tensor_dims, bool input_is_tensor,
|
||||
Converter* converter, TrtWeightStore* weight_store,
|
||||
error::Code expected_code = error::OK,
|
||||
const char* expected_error_msg_substr = nullptr) {
|
||||
nvinfer1::ITensor* input_tensor = converter->network()->addInput(
|
||||
"", nvinfer1::DataType::kFLOAT, GetTestDims(tensor_dims));
|
||||
TRT_TensorOrWeights input;
|
||||
if (input_is_tensor) {
|
||||
input = TRT_TensorOrWeights(converter->network()->addInput(
|
||||
"", nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
|
||||
} else {
|
||||
input = TRT_TensorOrWeights(weight_store->GetTempWeights(
|
||||
nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
|
||||
}
|
||||
nvinfer1::ITensor* output_tensor = nullptr;
|
||||
|
||||
for (bool validation_only : {false, true}) {
|
||||
const Status status = converter->PrepareTensorForShape(
|
||||
TRT_TensorOrWeights(input_tensor), GetTestDims(reshape_dims),
|
||||
validation_only, &output_tensor);
|
||||
input, GetTestDims(reshape_dims), validation_only, &output_tensor);
|
||||
if (expected_code == error::OK) {
|
||||
TF_EXPECT_OK(status);
|
||||
if (validation_only) {
|
||||
@ -748,49 +767,45 @@ void TestPrepareTensorForShape_Tensor(
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
|
||||
// Shape size doesn't match.
|
||||
TEST_F(ConverterTest, PrepareTensorForShape) {
|
||||
for (bool input_is_tensor : {true, false}) {
|
||||
// Shape size doesn't match.
|
||||
Reset();
|
||||
TestPrepareTensorForShape({2, 3, 5}, {2, 3, 6}, {}, input_is_tensor,
|
||||
converter_.get(), weight_store_,
|
||||
error::INVALID_ARGUMENT, "Incompatible shapes");
|
||||
|
||||
// Regular shape.
|
||||
Reset();
|
||||
TestPrepareTensorForShape({2, 3, 5}, {10, 3}, {10, 3}, input_is_tensor,
|
||||
converter_.get(), weight_store_);
|
||||
|
||||
// Reshape to zero rank.
|
||||
Reset();
|
||||
TestPrepareTensorForShape({1, 1}, {}, {}, input_is_tensor, converter_.get(),
|
||||
weight_store_);
|
||||
}
|
||||
|
||||
// Tensor input with zero rank.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({2, 3, 5}, {2, 3, 6}, {}, converter_.get(),
|
||||
error::INVALID_ARGUMENT,
|
||||
"Incompatible shapes");
|
||||
TestPrepareTensorForShape({}, {1, 1}, {1, 1}, /*input_is_tensor=*/true,
|
||||
converter_.get(), weight_store_);
|
||||
|
||||
// TODO(aaroey): we should check the case where uninferred dimensions are
|
||||
// not an exact divisor of input dimensions, e.g. for dims {-1, 7}.
|
||||
|
||||
// Infer shape, ok.
|
||||
// Infer tensor shape, ok.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({2, 3, 5}, {-1, 2}, {15, 2},
|
||||
converter_.get());
|
||||
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
|
||||
/*input_is_tensor=*/true, converter_.get(),
|
||||
weight_store_);
|
||||
|
||||
// Regular shape.
|
||||
// Infer weight shape, should fail.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({2, 3, 5}, {10, 3}, {10, 3},
|
||||
converter_.get());
|
||||
|
||||
// Input with zero rank.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({}, {1, 1}, {1, 1}, converter_.get());
|
||||
|
||||
// Reshape to zero rank.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({1, 1}, {}, {}, converter_.get());
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
|
||||
TRT_ShapedWeights weights = weight_store_->GetTempWeights(
|
||||
nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
|
||||
nvinfer1::ITensor* output_tensor = nullptr;
|
||||
for (bool validation_only : {false, true}) {
|
||||
TF_EXPECT_OK(converter_->PrepareTensorForShape(
|
||||
TRT_TensorOrWeights(weights), GetTestDims({10, 3}), validation_only,
|
||||
&output_tensor));
|
||||
if (validation_only) {
|
||||
EXPECT_EQ(nullptr, output_tensor);
|
||||
} else {
|
||||
ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
|
||||
}
|
||||
}
|
||||
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
|
||||
/*input_is_tensor=*/false, converter_.get(),
|
||||
weight_store_, error::INVALID_ARGUMENT,
|
||||
"Shape is not fully defined");
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, MaybeUpdateBatchSize) {
|
||||
@ -4910,6 +4925,279 @@ TEST_F(OpConverterTest, ConvertArgMinMax) {
|
||||
// TestConvertArgMinMax<ops::ArgMax, DT_INT32>(this);
|
||||
}
|
||||
|
||||
// Get the NodeDef for DepthToSpace or SpaceToDepth.
|
||||
template <typename OpType>
|
||||
NodeDef GetDepthSpaceShuffleNodeDef(DataType dtype, int block_size,
|
||||
string data_format) {
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto input = ops::Placeholder(s.WithOpName("input"), dtype);
|
||||
auto attrs = OpType::DataFormat(data_format);
|
||||
auto shuffle = OpType(s.WithOpName("my_shuffle"), input, block_size, attrs);
|
||||
return shuffle.operation.node()->def();
|
||||
}
|
||||
|
||||
template <typename CType>
|
||||
struct DepthSpaceShuffleTestParams {
|
||||
std::vector<int> input_dims;
|
||||
std::vector<CType> input_value;
|
||||
int block_size;
|
||||
string data_format;
|
||||
std::vector<int> expected_output_dims;
|
||||
std::vector<CType> expected_output;
|
||||
};
|
||||
|
||||
template <typename OpType, DataType dtype, typename CType>
|
||||
void TestConvertDepthSpaceShuffle(
|
||||
OpConverterTest* test,
|
||||
const std::vector<DepthSpaceShuffleTestParams<CType>>& params) {
|
||||
for (int i = 0; i < params.size(); ++i) {
|
||||
test->Reset();
|
||||
|
||||
NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>(
|
||||
dtype, params[i].block_size, params[i].data_format);
|
||||
test->AddTestTensor("input", params[i].input_dims, 1,
|
||||
TfDataTypeToTrt(dtype));
|
||||
test->RunValidationAndConversion(node_def);
|
||||
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_shuffle", &output));
|
||||
EXPECT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
|
||||
output.tensor()->getDimensions());
|
||||
|
||||
DataVec input_data{{"input", test::AsTensor<CType>(params[i].input_value)}};
|
||||
DataVec output_data{{"my_shuffle", ConstructTensor<CType>(
|
||||
params[i].expected_output.size())}};
|
||||
test->BuildAndRun(
|
||||
input_data, &output_data,
|
||||
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(params[i].expected_output));
|
||||
}
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void TestConvertDepthToSpace(OpConverterTest* test) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
const std::vector<CType> common_input = InitTestVector<CType>(16);
|
||||
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
|
||||
{
|
||||
/*input_shape=*/{4, 2, 2},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NCHW",
|
||||
/*expected_output_dims=*/{1, 4, 4},
|
||||
/*expected_output=*/
|
||||
CastTestVector<int, CType>(
|
||||
{0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15}),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{2, 2, 4},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NHWC",
|
||||
/*expected_output_dims=*/{4, 4, 1},
|
||||
/*expected_output=*/
|
||||
CastTestVector<int, CType>(
|
||||
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{16, 1, 1},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/4,
|
||||
/*data_format=*/"NCHW",
|
||||
/*expected_output_dims=*/{1, 4, 4},
|
||||
/*expected_output=*/InitTestVector<CType>(16),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{2, 2, 8},
|
||||
/*input_value=*/InitTestVector<CType>(32),
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NHWC",
|
||||
/*expected_output_dims=*/{4, 4, 2},
|
||||
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
|
||||
9, 10, 11, 4, 5,
|
||||
6, 7, 12, 13, 14,
|
||||
15, 16, 17, 18, 19,
|
||||
24, 25, 26, 27, 20,
|
||||
21, 22, 23, 28, 29,
|
||||
30, 31}),
|
||||
},
|
||||
};
|
||||
|
||||
TestConvertDepthSpaceShuffle<ops::DepthToSpace, dtype, CType>(test, params);
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertDepthToSpace) {
|
||||
{
|
||||
// Input list is empty, should fail.
|
||||
NodeDef node_def = MakeNodeDef("my_shuffle", "DepthToSpace", {});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"DepthToSpace got 0 inputs but expected 1, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Input is a weight, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
|
||||
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
|
||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||
"The input \"input\" for DepthToSpace must be a "
|
||||
"tensor, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Input rank != 4
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
|
||||
AddTestTensor("input", {16, 32});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"The input to DepthToSpace must be rank 4, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Channels not divisible by block_size, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 3, "NCHW");
|
||||
AddTestTensor("input", {16, 32, 32});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Number of channels must be divisible by "
|
||||
"block_size*block_size, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Unsupported format, should fail.
|
||||
Reset();
|
||||
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(
|
||||
DT_FLOAT, 2, "NCHW_VECT_C");
|
||||
AddTestTensor("input", {16, 32, 32});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::UNIMPLEMENTED,
|
||||
"Data format NCHW_VECT_C is not supported, at my_shuffle");
|
||||
}
|
||||
|
||||
TestConvertDepthToSpace<DT_FLOAT>(this);
|
||||
TestConvertDepthToSpace<DT_HALF>(this);
|
||||
TestConvertDepthToSpace<DT_INT32>(this);
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void TestConvertSpaceToDepth(OpConverterTest* test) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
const std::vector<CType> common_input = InitTestVector<CType>(16);
|
||||
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
|
||||
{
|
||||
/*input_shape=*/{1, 4, 4},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NCHW",
|
||||
/*expected_output_dims=*/{4, 2, 2},
|
||||
/*expected_output=*/
|
||||
CastTestVector<int, CType>(
|
||||
{0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15}),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{4, 4, 1},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NHWC",
|
||||
/*expected_output_dims=*/{2, 2, 4},
|
||||
/*expected_output=*/
|
||||
CastTestVector<int, CType>(
|
||||
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{1, 4, 4},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/4,
|
||||
/*data_format=*/"NCHW",
|
||||
/*expected_output_dims=*/{16, 1, 1},
|
||||
/*expected_output=*/InitTestVector<CType>(16),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{4, 4, 2},
|
||||
/*input_value=*/InitTestVector<CType>(32),
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NHWC",
|
||||
/*expected_output_dims=*/{2, 2, 8},
|
||||
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
|
||||
9, 10, 11, 4, 5,
|
||||
6, 7, 12, 13, 14,
|
||||
15, 16, 17, 18, 19,
|
||||
24, 25, 26, 27, 20,
|
||||
21, 22, 23, 28, 29,
|
||||
30, 31}),
|
||||
},
|
||||
};
|
||||
|
||||
TestConvertDepthSpaceShuffle<ops::SpaceToDepth, dtype, CType>(test, params);
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertSpaceToDepth) {
|
||||
{
|
||||
// Input list is empty, should fail.
|
||||
NodeDef node_def = MakeNodeDef("my_shuffle", "SpaceToDepth", {});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"SpaceToDepth got 0 inputs but expected 1, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Input is a weight, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
|
||||
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
|
||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||
"The input \"input\" for SpaceToDepth must be a "
|
||||
"tensor, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Input rank != 4
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
|
||||
AddTestTensor("input", {16, 32});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"The input to SpaceToDepth must be rank 4, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Width not divisible by block_size, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
|
||||
AddTestTensor("input", {16, 9, 32});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Width and height must be divisible by "
|
||||
"block_size, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Height not divisible by block_size, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
|
||||
AddTestTensor("input", {16, 32, 9});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Width and height must be divisible by "
|
||||
"block_size, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Unsupported format, should fail.
|
||||
Reset();
|
||||
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(
|
||||
DT_FLOAT, 2, "NCHW_VECT_C");
|
||||
AddTestTensor("input", {16, 32, 32});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::UNIMPLEMENTED,
|
||||
"Data format NCHW_VECT_C is not supported, at my_shuffle");
|
||||
}
|
||||
|
||||
TestConvertSpaceToDepth<DT_FLOAT>(this);
|
||||
TestConvertSpaceToDepth<DT_HALF>(this);
|
||||
TestConvertSpaceToDepth<DT_INT32>(this);
|
||||
}
|
||||
|
||||
} // namespace convert
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
|
||||
@ -290,17 +291,17 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
|
||||
VLOG(1) << "Executing TRT calibration: " << name();
|
||||
helper->Ref();
|
||||
core::ScopedUnref sc(helper);
|
||||
auto res_mgr = ctx->resource_manager();
|
||||
TRTCalibrationResource* calib_res = nullptr;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
res_mgr->LookupOrCreate(
|
||||
"TF_TRT_Calibration", name(),
|
||||
ctx->resource_manager()->LookupOrCreate(
|
||||
"TF-TRT-Calibration", name(),
|
||||
reinterpret_cast<SerializableResourceBase**>(&calib_res),
|
||||
{[ctx, this](SerializableResourceBase** cr) -> Status {
|
||||
return this->AllocateCalibrationResources(ctx, cr);
|
||||
}}));
|
||||
core::ScopedUnref calib_sc(calib_res);
|
||||
int num_inputs = ctx->num_inputs();
|
||||
// TODO(laigd): need to check that input shape matches.
|
||||
// Pass input data to calibrator
|
||||
std::unordered_map<string, void*> input_data;
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
@ -522,10 +523,22 @@ EngineContext* TRTEngineOp::GetEngine(
|
||||
// TODO(tmorris): using first input to get batch size - is this reliable?
|
||||
const int batch_size = input_shapes[0].dim_size(0);
|
||||
|
||||
// Get engine cache
|
||||
// Canonicalize the op name by removing the scopes, if any. This is mainly
// because in TF 2.x the function graph can be instantiated in various ways,
// each of which prepends scope names to the TRTEngineOp names. Using the
// instantiated op name directly would therefore create many different engine
// caches, while we want all instantiations of the same subgraph to share a
// single cache.
absl::string_view resource_name = name();
|
||||
size_t last_slash = resource_name.find_last_of('/');
|
||||
if (last_slash != absl::string_view::npos) {
|
||||
resource_name.remove_prefix(last_slash + 1);
|
||||
}
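In effect the canonicalization keeps only the substring after the last '/', so differently scoped instantiations of the same TRTEngineOp land on one cache entry. A toy sketch (canonical_engine_name and the op names are invented for illustration):

def canonical_engine_name(op_name):
    # Mirror resource_name.remove_prefix(last_slash + 1).
    return op_name.rsplit('/', 1)[-1]

assert canonical_engine_name("while/body/my_trt_op_0") == "my_trt_op_0"
assert canonical_engine_name("my_trt_op_0") == "my_trt_op_0"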
|
||||
|
||||
// Get engine cache.
|
||||
TRTEngineCacheResource* cache_res = nullptr;
|
||||
auto status = ctx->resource_manager()->LookupOrCreate(
|
||||
"TRTEngineCache", name(), &cache_res,
|
||||
"TF-TRT-Engine-Cache", string(resource_name), &cache_res,
|
||||
{[this, ctx](TRTEngineCacheResource** cr) -> Status {
|
||||
*cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
|
||||
return Status::OK();
|
||||
@ -632,12 +645,13 @@ EngineContext* TRTEngineOp::GetEngine(
|
||||
cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
|
||||
return &empty_context;
|
||||
}
|
||||
VLOG(1) << "Conversion is done";
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
||||
engine->createExecutionContext());
|
||||
cache.emplace(engine_input_shapes,
|
||||
absl::make_unique<EngineContext>(std::move(engine),
|
||||
std::move(exec_context)));
|
||||
VLOG(1) << "Added new engine to cache of " << name()
|
||||
<< ". Cache size: " << cache.size();
|
||||
}
|
||||
return cache.at(engine_input_shapes).get();
|
||||
}
|
||||
|
@ -459,7 +459,7 @@ Status SegmentGraph(const Graph* tf_graph,
}
LOG(INFO) << msg << "(For more information see "
<< "https://docs.nvidia.com/deeplearning"
<< "/dgx/integrate-tf-trt/index.html#support-ops).";
<< "/dgx/tf-trt-user-guide/index.html#supported-ops).";

// The segmentation algorithm below visits nodes in reverse topological order
// and attempts to merge nodes along output edges. That means that subgraphs

74
tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx,
|
||||
size_t capacity)
|
||||
: cache_(capacity) {
|
||||
auto device = ctx->device();
|
||||
auto alloc = device->GetAllocator(AllocatorAttributes());
|
||||
if (!alloc) {
|
||||
LOG(ERROR) << "Can't find device allocator for gpu device "
|
||||
<< device->name();
|
||||
allocator_ = nullptr;
|
||||
} else {
|
||||
allocator_.reset(new TRTDeviceAllocator(alloc));
|
||||
}
|
||||
}
|
||||
|
||||
TRTEngineCacheResource::~TRTEngineCacheResource() {
|
||||
VLOG(1) << "Destroying TRTEngineCacheResource...";
|
||||
}
|
||||
|
||||
string TRTEngineCacheResource::DebugString() const {
|
||||
std::stringstream oss;
|
||||
using std::dec;
|
||||
using std::endl;
|
||||
using std::hex;
|
||||
oss << "TRTEngineCacheResource: ";
|
||||
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
|
||||
oss << "LRUCache = " << hex << &cache_ << dec << endl;
|
||||
oss << "Containing " << cache_.size() << " entries: " << endl;
|
||||
for (const auto& item : cache_) {
|
||||
mutex_lock lock(item.second->mu);
|
||||
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
|
||||
<< "ICudaEngine: " << item.second->cuda_engine.get() << ", "
|
||||
<< "IExecutionContext: " << item.second->execution_context.get() << dec
|
||||
<< endl;
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_TENSORRT
|
||||
#endif // GOOGLE_CUDA
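The cache_ member used above is an LRU cache keyed by the engine's input shapes and capped at the capacity passed in (max_cached_engines_ at the call site). A standalone Python sketch of that keying and eviction behavior (ShapeKeyedLRUCache and build_fn are invented names, not the TF-TRT LRUCache API):

from collections import OrderedDict

class ShapeKeyedLRUCache:
    def __init__(self, capacity):
        self.capacity = capacity
        self._entries = OrderedDict()

    def get_or_create(self, shapes, build_fn):
        key = tuple(map(tuple, shapes))
        if key in self._entries:
            self._entries.move_to_end(key)       # mark as most recently used
            return self._entries[key]
        value = build_fn()                       # e.g. build a new engine
        self._entries[key] = value
        if len(self._entries) > self.capacity:
            self._entries.popitem(last=False)    # evict least recently used
        return value

cache = ShapeKeyedLRUCache(capacity=2)
engine = cache.get_or_create([(8, 3, 224, 224)], lambda: "engine-A")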
|
@ -141,36 +141,11 @@ struct EngineContext {
|
||||
|
||||
class TRTEngineCacheResource : public ResourceBase {
|
||||
public:
|
||||
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity)
|
||||
: cache_(capacity) {
|
||||
auto device = ctx->device();
|
||||
auto alloc = device->GetAllocator(AllocatorAttributes());
|
||||
if (!alloc) {
|
||||
LOG(ERROR) << "Can't find device allocator for gpu device "
|
||||
<< device->name();
|
||||
allocator_ = nullptr;
|
||||
} else {
|
||||
allocator_.reset(new TRTDeviceAllocator(alloc));
|
||||
}
|
||||
}
|
||||
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity);
|
||||
|
||||
string DebugString() const override {
|
||||
std::stringstream oss;
|
||||
using std::dec;
|
||||
using std::endl;
|
||||
using std::hex;
|
||||
oss << "TRTEngineCacheResource: ";
|
||||
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
|
||||
oss << "LRUCache = " << hex << &cache_ << dec << endl;
|
||||
oss << "Containing " << cache_.size() << " entries: " << endl;
|
||||
for (const auto& item : cache_) {
|
||||
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
|
||||
<< "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", "
|
||||
<< "IExecutionContext: " << item.second.get()->execution_context.get()
|
||||
<< dec << endl;
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
~TRTEngineCacheResource() override;
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
// Keep device allocator for TRT.
|
||||
std::unique_ptr<TRTBaseAllocator> allocator_;
|
||||
|
@ -44,6 +44,7 @@ class BatchMatMulOp : public XlaOpKernel {
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("BatchMatMul"), BatchMatMulOp);
|
||||
REGISTER_XLA_OP(Name("BatchMatMulV2"), BatchMatMulOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
// XLA implementations of Categorical op.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
@ -140,8 +141,6 @@ class StatelessCategoricalOp : public CategoricalOp {
|
||||
xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type,
|
||||
XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp seed = ctx->Input(2);
|
||||
auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
if (uniform_shape.element_type() == xla::BF16) {
|
||||
@ -150,8 +149,8 @@ class StatelessCategoricalOp : public CategoricalOp {
|
||||
// We want a number in (0, 1) rather than [0, 1) or (0, 1]:
|
||||
// * log(-log(0)) is ∞.
|
||||
// * log(-log(1)) is -∞.
|
||||
auto uniforms = xla::StatelessRngUniform(
|
||||
{seed0, seed1}, uniform_shape,
|
||||
xla::XlaOp uniforms = StatelessRngUniform(
|
||||
seed, uniform_shape,
|
||||
xla::MinPositiveNormalValue(builder, uniform_shape.element_type()),
|
||||
xla::One(builder, uniform_shape.element_type()));
|
||||
return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type);
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
// XLA-specific Ops for 2D convolution.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
@ -293,10 +294,9 @@ xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
|
||||
return attrs;
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
|
||||
xla::XlaOp conv_input,
|
||||
xla::XlaOp filter,
|
||||
const ConvOpAttrs& attrs) {
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
|
||||
StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter,
|
||||
const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
|
||||
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
|
||||
|
||||
auto* builder = conv_input.builder();
|
||||
@ -377,12 +377,14 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
|
||||
return xla::ConvGeneralDilated(
|
||||
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dims,
|
||||
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count);
|
||||
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count,
|
||||
/*batch_group_count=*/1, precision_config);
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config) {
|
||||
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
|
||||
|
||||
int num_dims = attrs.num_spatial_dims + 2;
|
||||
@ -456,13 +458,14 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
/*feature_group_count=*/
|
||||
attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
|
||||
filter_shape.dimensions(attrs.num_spatial_dims + 1)
|
||||
: feature_group_count);
|
||||
: feature_group_count,
|
||||
/*batch_group_count=*/1, precision_config);
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
||||
StringPiece type_string, xla::XlaOp activations,
|
||||
const xla::Shape& filter_shape, xla::XlaOp gradients,
|
||||
const ConvOpAttrs& attrs) {
|
||||
const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
|
||||
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
|
||||
|
||||
auto* builder = activations.builder();
|
||||
@ -612,7 +615,8 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
||||
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
|
||||
rhs_dilation, dnums,
|
||||
/*feature_group_count=*/feature_group_count,
|
||||
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1);
|
||||
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1,
|
||||
precision_config);
|
||||
|
||||
if (!use_batch_group_count && attrs.depthwise) {
|
||||
filter_backprop = ContractFilterForDepthwiseBackprop(
|
||||
|
@ -53,17 +53,19 @@ struct ConvOpAttrs {
|
||||
|
||||
// Creates a new XLA forward or backward convolution with the given inputs and
|
||||
// attributes.
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
|
||||
xla::XlaOp conv_input,
|
||||
xla::XlaOp filter,
|
||||
const ConvOpAttrs& attrs);
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
|
||||
StringPiece type_string, xla::XlaOp conv_input, xla::XlaOp filter,
|
||||
const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config = nullptr);
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config = nullptr);
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
||||
StringPiece type_string, xla::XlaOp activations,
|
||||
const xla::Shape& filter_shape, xla::XlaOp gradients,
|
||||
const ConvOpAttrs& attrs);
|
||||
const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config = nullptr);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -22,6 +22,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
// Returns a tensor containing 'shape' random values uniformly distributed in
// the range [minval, maxval). The raw random bits are generated by the given
// `bit_generator` and converted to the requested data type and range. This
// routine requires 2 32-bit integer seeds and currently only supports 'shape's
// of type F32, S32 and S64.
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
                               xla::XlaOp minval, xla::XlaOp maxval);

// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.
// It masks the last 16 bit. With normal rounding, values near "maxval" would be
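For context, a short illustration of how this kernel-level helper surfaces in the Python API. This is a sketch only: it assumes a TensorFlow build that exposes `tf.random.stateless_uniform`, which, as far as I can tell, is the op these tf2xla kernels back.

```python
# Illustrative sketch, not part of this change. Assumes tf.random.stateless_uniform
# is available in the installed TensorFlow.
import tensorflow as tf

seed = [7, 42]  # the two 32-bit integer seeds mentioned in the comment above
a = tf.random.stateless_uniform([2, 3], seed=seed, minval=0.0, maxval=1.0)
b = tf.random.stateless_uniform([2, 3], seed=seed, minval=0.0, maxval=1.0)
# Identical (seed, shape) inputs produce identical values; the output of the
# stateless variants is a pure function of the seed rather than of hidden state.
```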
@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel {
    TensorShape shape;
    int64 product = 1;
    int unknown_index = -1;
    bool shape_has_zero_dim = false;
    for (int d = 0; d < num_dims; ++d) {
      const int32 size = shape_input[d];
      if (size == -1) {
@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel {
                                    unknown_index, " and ", d));
        unknown_index = d;
        shape.AddDim(1);
      } else if (size == 0) {
        // We don't include zero-sized dimension in product, so that we can
        // still calculate number of elements for non-zero-sized dimensions and
        // therefore infer their shapes.
        shape.AddDim(size);
        shape_has_zero_dim = true;
      } else {
        OP_REQUIRES(ctx, size >= 0,
                    errors::InvalidArgument(
@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel {
      }
    }
    if (unknown_index != -1) {
      OP_REQUIRES(
          ctx, product > 0,
          errors::InvalidArgument("Reshape cannot infer the missing input size "
                                  "for an empty tensor unless all specified "
                                  "input sizes are non-zero"));
      const int64 missing = input_shape.num_elements() / product;
      OP_REQUIRES(
          ctx, product * missing == input_shape.num_elements(),
          errors::InvalidArgument(
              "Input to reshape is a tensor with ", input_shape.num_elements(),
              " values, but the requested shape requires a multiple of ",
              product));
      int64 input_num_elements = 1;
      bool input_has_zero_dim = false;
      for (int dim = 0; dim < input_shape.dims(); dim++) {
        // For zero dimension, we don't count it into `input_num_elements`
        // unless `sizes` has no zero dimension, so we are still able to
        // infer shapes for other dimensions.
        if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) {
          input_num_elements *= input_shape.dim_size(dim);
        } else {
          input_has_zero_dim = true;
        }
      }

      const int64 missing = input_num_elements / product;
      if (!input_has_zero_dim) {
        OP_REQUIRES(
            ctx, product * missing == input_num_elements,
            errors::InvalidArgument(
                "Input to reshape is a tensor with ", input_num_elements,
                " values, but the requested shape requires a multiple of ",
                product));
      }
      shape.set_dim(unknown_index, missing);
    }
    OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),

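A minimal Python sketch of the inference rule introduced above, with the error checks omitted. It only illustrates how zero-sized dimensions are excluded from both products so that a `-1` dimension can still be inferred when the input contains a 0 dimension; names and helper structure are my own.

```python
# Sketch of the shape-inference rule above (not TF code; error checks omitted).
def infer_reshape(input_dims, sizes):
    product = 1
    unknown_index = None
    shape_has_zero_dim = False
    out = []
    for d, size in enumerate(sizes):
        if size == -1:
            unknown_index = d
            out.append(1)          # placeholder, filled in below
        elif size == 0:
            out.append(0)          # zero dims are not counted in `product`
            shape_has_zero_dim = True
        else:
            out.append(size)
            product *= size
    if unknown_index is not None:
        input_num_elements = 1
        for dim in input_dims:
            # Skip zero input dims when the requested shape also has a zero dim,
            # mirroring the loop added in the kernel above.
            if dim > 0 or not shape_has_zero_dim:
                input_num_elements *= dim
        out[unknown_index] = input_num_elements // product
    return out

# Reshaping a [3, 0, 4] tensor to [0, -1] now infers the missing dimension
# from the non-zero dims: 3 * 4 = 12, so -1 -> 12.
assert infer_reshape([3, 0, 4], [0, -1]) == [0, 12]
```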
@ -35,127 +35,50 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter(
|
||||
xla::XlaOp counter, const int64 size) {
|
||||
auto builder = counter.builder();
|
||||
auto input_u64 = Iota(builder, xla::U64, size);
|
||||
input_u64 = input_u64 + counter;
|
||||
counter = counter + xla::ConstantR0<uint64>(builder, size);
|
||||
return std::make_pair(xla::Uint64ToUint32s(input_u64), counter);
|
||||
}
|
||||
|
||||
// `StatelessRngUniformU32` uses ThreeFry2x32’s counter space too
|
||||
// wastefully, only able to generate 2^32*2 int32 numbers for each key, while
|
||||
// the real capacity is 2^64*2. Counter-space efficiency is important for
|
||||
// stateful ops, hence the following 2 new functions.
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
auto builder = key.builder();
|
||||
const int64 size = xla::ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = xla::CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
auto inputs_counter = GetInputsFromCounter(counter, half_size);
|
||||
auto inputs = inputs_counter.first;
|
||||
counter = inputs_counter.second;
|
||||
auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key));
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
auto result = ConcatInDim(builder, outputs, 0);
|
||||
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
|
||||
counter);
|
||||
}
|
||||
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
const int64 size = xla::ShapeUtil::ElementsIn(shape);
|
||||
auto inputs_counter = GetInputsFromCounter(counter, size);
|
||||
auto inputs = inputs_counter.first;
|
||||
counter = inputs_counter.second;
|
||||
auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
auto result = Uint32sToUint64(outputs);
|
||||
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
|
||||
counter);
|
||||
}
|
||||
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
|
||||
xla::XlaOp counter,
|
||||
const xla::Shape& shape,
|
||||
xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
auto builder = key.builder();
|
||||
xla::RngOutput StatefulRngUniform(xla::XlaOp key, xla::XlaOp initial_state,
|
||||
const xla::Shape& shape, xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case xla::F32: {
|
||||
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
|
||||
auto bits = bits_counter.first;
|
||||
counter = bits_counter.second;
|
||||
return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval),
|
||||
counter);
|
||||
}
|
||||
case xla::U32: // fall through
|
||||
case xla::S32: {
|
||||
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
|
||||
auto bits = bits_counter.first;
|
||||
counter = bits_counter.second;
|
||||
return std::make_pair(
|
||||
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32),
|
||||
counter);
|
||||
}
|
||||
case xla::U64: // fall through
|
||||
case xla::S64: {
|
||||
auto bits_counter = StatefulRngUniformU64(key, counter, shape);
|
||||
auto bits = bits_counter.first;
|
||||
counter = bits_counter.second;
|
||||
return std::make_pair(
|
||||
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64),
|
||||
counter);
|
||||
}
|
||||
case xla::F32:
|
||||
return xla::UniformF32Distribution(
|
||||
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
|
||||
case xla::U32:
|
||||
case xla::S32:
|
||||
case xla::U64:
|
||||
case xla::S64:
|
||||
return UniformIntDistribution(
|
||||
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
|
||||
default:
|
||||
return std::make_pair(
|
||||
builder->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by "
|
||||
"StatefulRngUniform; got: %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
counter);
|
||||
return {key.builder()->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by "
|
||||
"StatefulRngUniform; got %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
initial_state};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename A, typename B, typename A2>
|
||||
std::pair<A2, B> map_first(std::function<A2(A)> f, std::pair<A, B> p) {
|
||||
return std::make_pair(f(p.first), p.second);
|
||||
}
|
||||
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformFullInt(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
xla::RngOutput StatefulRngUniformFullInt(xla::XlaOp key,
|
||||
xla::XlaOp initial_state,
|
||||
const xla::Shape& shape) {
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
xla::RngOutput output = xla::ThreeFryBitGenerator(key, initial_state, shape);
|
||||
switch (type) {
|
||||
case xla::U32:
|
||||
return StatefulRngUniformU32(key, counter, shape);
|
||||
case xla::S32: {
|
||||
// Needs explicit function type because of type-inference failure.
|
||||
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
|
||||
return BitcastConvertType(x, xla::S32);
|
||||
};
|
||||
return map_first(f, StatefulRngUniformU32(key, counter, shape));
|
||||
}
|
||||
case xla::U64:
|
||||
return StatefulRngUniformU64(key, counter, shape);
|
||||
case xla::S64: {
|
||||
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
|
||||
return BitcastConvertType(x, xla::S64);
|
||||
};
|
||||
return map_first(f, StatefulRngUniformU64(key, counter, shape));
|
||||
}
|
||||
return output;
|
||||
case xla::S32:
|
||||
case xla::S64:
|
||||
output.value = BitcastConvertType(output.value, type);
|
||||
return output;
|
||||
default:
|
||||
auto builder = key.builder();
|
||||
return std::make_pair(
|
||||
builder->ReportError(xla::Unimplemented(
|
||||
return {
|
||||
key.builder()->ReportError(xla::Unimplemented(
|
||||
"Types other than U32, S32, U64 and S64 are not implemented by "
|
||||
"StatefulRngUniformFullInt; got: %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
counter);
|
||||
initial_state};
|
||||
}
|
||||
}
|
||||
|
||||
@ -177,15 +100,15 @@ xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
|
||||
0);
|
||||
}
|
||||
|
||||
using sampler_return_type = xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>>;
|
||||
using SamplerReturnType = xla::StatusOr<xla::RngOutput>;
|
||||
|
||||
// A helper function containing the common part of several kernels below.
|
||||
// Precondition: 'algorithm' and 'shape' are compile-time constants.
|
||||
Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
||||
int alg_input_idx, int shape_input_idx,
|
||||
std::function<sampler_return_type(xla::XlaOp, xla::XlaOp,
|
||||
TensorShape)> const&
|
||||
sample_with_threefry) {
|
||||
Status CompileImpl(
|
||||
XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
|
||||
int shape_input_idx,
|
||||
std::function<SamplerReturnType(xla::XlaOp, xla::XlaOp, TensorShape)> const&
|
||||
sampler) {
|
||||
auto alg_shape = ctx->InputShape(alg_input_idx);
|
||||
if (alg_shape.dims() != 0) {
|
||||
return errors::InvalidArgument("algorithm must be of shape [], not ",
|
||||
@ -215,24 +138,22 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
|
||||
|
||||
static constexpr int COUNTER_SIZE = 1;
|
||||
auto counter = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64);
|
||||
static constexpr int kStateSize = 1;
|
||||
auto state = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
|
||||
auto key = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}),
|
||||
{}),
|
||||
xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
|
||||
xla::U64);
|
||||
|
||||
auto status_or_value = sample_with_threefry(counter, key, shape);
|
||||
auto status_or_value = sampler(state, key, shape);
|
||||
if (!status_or_value.ok()) {
|
||||
return status_or_value.status();
|
||||
}
|
||||
auto output_counter = status_or_value.ConsumeValueOrDie();
|
||||
auto output = output_counter.first;
|
||||
counter = output_counter.second;
|
||||
ctx->SetOutput(0, output);
|
||||
auto builder = ctx->builder();
|
||||
var = ConcatScalars(builder, {counter, key});
|
||||
xla::RngOutput value_state = status_or_value.ConsumeValueOrDie();
|
||||
state = value_state.state;
|
||||
ctx->SetOutput(0, value_state.value);
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
var = ConcatScalars(builder, {state, key});
|
||||
xla::PrimitiveType state_element_type;
|
||||
TF_RETURN_IF_ERROR(
|
||||
DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
|
||||
@ -252,23 +173,22 @@ class StatefulUniformOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry = [builder, this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
auto sampler = [builder, this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::RngOutput uniform_state = StatefulRngUniform(
|
||||
key, state, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
xla::XlaOp uniform = uniform_state.value;
|
||||
state = uniform_state.state;
|
||||
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
||||
return {{uniform, counter}};
|
||||
return {{uniform, state}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
/*shape_input_idx=*/2, sampler));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -293,30 +213,20 @@ class StatefulStandardNormalOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry =
|
||||
auto sampler =
|
||||
// Needs explicit lambda return type because it fails to be inferred.
|
||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
[this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape,
|
||||
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
// Convert uniform distribution to normal distribution by computing
|
||||
// sqrt(2) * erfinv(x)
|
||||
auto normal =
|
||||
xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
|
||||
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
||||
return {{normal, counter}};
|
||||
xla::RngOutput value_state = xla::NormalF32Distribution(
|
||||
key, state, xla::ThreeFryBitGenerator, xla_shape);
|
||||
xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
|
||||
return {{normal, value_state.state}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
/*shape_input_idx=*/2, sampler));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -341,27 +251,27 @@ class StatefulTruncatedNormalOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry =
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
auto sampler =
|
||||
// Needs explicit lambda return type because it fails to be inferred.
|
||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
[builder, this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape,
|
||||
xla::RngOutput uniform_result = StatefulRngUniform(
|
||||
key, state, xla_shape,
|
||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||
xla::One(builder, xla_shape.element_type()));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
xla::XlaOp uniform = uniform_result.value;
|
||||
state = uniform_result.state;
|
||||
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
||||
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
|
||||
return {{truncated_normal, counter}};
|
||||
return {{truncated_normal, state}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
/*shape_input_idx=*/2, sampler));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -388,11 +298,11 @@ class StatefulUniformIntOp : public XlaOpKernel {
|
||||
xla::XlaOp minval = ctx->Input(3);
|
||||
xla::XlaOp maxval = ctx->Input(4);
|
||||
auto sample_with_threefry = [minval, maxval, this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
return StatefulRngUniform(key, counter, xla_shape, minval, maxval);
|
||||
return StatefulRngUniform(key, state, xla_shape, minval, maxval);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
@ -420,12 +330,11 @@ class StatefulUniformFullIntOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto sample_with_threefry = [this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
auto sample_with_threefry = [this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
return StatefulRngUniformFullInt(key, counter, xla_shape);
|
||||
return StatefulRngUniformFullInt(key, state, xla_shape);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
|
@ -36,8 +36,8 @@ namespace tensorflow {
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
  if (dtype == DT_BFLOAT16) {
    xla::XlaBuilder* builder = input.builder();
    auto output = xla::BitcastConvertType(input, xla::U32) &
                  xla::ConstantR0<uint32>(builder, 0xFFFF0000);
    xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
                        xla::ConstantR0<uint32>(builder, 0xFFFF0000);
    return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
                                   xla::BF16);
  } else {
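For reference, a standalone Python sketch of the masking trick in `MaybeConvertF32ToBF16` (illustration only, not TF code): zeroing the low 16 bits of a float32 keeps the sign, the 8 exponent bits, and the top 7 mantissa bits, which is exactly the bfloat16 payload.

```python
# Sketch of the bfloat16 truncation above (illustration only).
import struct

def truncate_to_bfloat16(x: float) -> float:
    bits = struct.unpack('<I', struct.pack('<f', x))[0]  # reinterpret f32 as u32
    bits &= 0xFFFF0000                                   # drop the low 16 bits
    return struct.unpack('<f', struct.pack('<I', bits))[0]

# Truncation (rather than round-to-nearest) can only move a value toward zero,
# so a sample drawn from [minval, maxval) cannot be bumped up to equal maxval.
print(truncate_to_bfloat16(0.3333333))  # ~0.33203125
```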
@ -45,22 +45,36 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
|
||||
// Convert uniform distribution to normal distribution by computing
|
||||
// sqrt(2) * erfinv(x)
|
||||
return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
|
||||
}
|
||||
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
|
||||
xla::XlaOp minval, xla::XlaOp maxval) {
|
||||
xla::XlaBuilder* builder = seeds.builder();
|
||||
|
||||
// A wrapper of xla::StatelessRngUniform. Returns an op that produces random
|
||||
// values with uniform distribution in the range [minval, maxval) for the given
|
||||
// shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and
|
||||
// S64 are implemented.
|
||||
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused,
|
||||
xla::XlaOp seed, xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval);
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case xla::F32:
|
||||
return xla::UniformF32Distribution(key, initial_state,
|
||||
xla::ThreeFryBitGenerator, minval,
|
||||
maxval, shape)
|
||||
.value;
|
||||
case xla::S32: // fall through
|
||||
case xla::S64:
|
||||
return UniformIntDistribution(key, initial_state,
|
||||
xla::ThreeFryBitGenerator, minval, maxval,
|
||||
shape)
|
||||
.value;
|
||||
break;
|
||||
default:
|
||||
return builder->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, S32 and S64 are not implemented by "
|
||||
"StatelessRngUniform; got %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type)));
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -86,8 +100,8 @@ class StatelessRandomUniformOp : public XlaOpKernel {
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
||||
xla_shape, dtype_, seed, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::XlaOp uniform = StatelessRngUniform(
|
||||
seed, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
||||
ctx->SetOutput(0, uniform);
|
||||
@ -136,8 +150,8 @@ class StatelessRandomUniformIntOp : public XlaOpKernel {
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
xla::XlaOp uniform =
|
||||
StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval);
|
||||
xla::XlaOp uniform = StatelessRngUniform(seed, xla_shape, minval, maxval);
|
||||
|
||||
ctx->SetOutput(0, uniform);
|
||||
}
|
||||
|
||||
@ -170,14 +184,20 @@ class StatelessRandomNormalOp : public XlaOpKernel {
|
||||
errors::InvalidArgument("seed must have shape [2], not ",
|
||||
seed_shape.DebugString()));
|
||||
xla::XlaOp seed = ctx->Input(1);
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
||||
xla_shape, dtype_, seed,
|
||||
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform);
|
||||
|
||||
xla::XlaBuilder* builder = seed.builder();
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp normal =
|
||||
xla::NormalF32Distribution(key, initial_state,
|
||||
xla::ThreeFryBitGenerator, xla_shape)
|
||||
.value;
|
||||
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
||||
ctx->SetOutput(0, normal);
|
||||
}
|
||||
@ -215,8 +235,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
||||
xla_shape, dtype_, seed,
|
||||
xla::XlaOp uniform = StatelessRngUniform(
|
||||
seed, xla_shape,
|
||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||
xla::One(builder, xla_shape.element_type()));
|
||||
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
||||
|
@ -165,7 +165,7 @@ Status RewriteAndPruneGraph(
  TF_RETURN_IF_ERROR(
      AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
  VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
  PruneForReverseReachability(graph, retval_nodes);
  PruneForReverseReachability(graph, std::move(retval_nodes));
  FixupSourceAndSinkEdges(graph);
  VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
  // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
|
@ -287,6 +287,7 @@ tf_cc_test(
|
||||
":xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/hash:hash_testing",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -32,8 +32,12 @@ XlaOp RotateLeftU32(XlaOp v, int distance) {
|
||||
ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
// The internal state of the Three Fry implementation.
|
||||
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
||||
|
||||
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
||||
XlaBuilder* builder = input[0].builder();
|
||||
key[0] = BitcastConvertType(key[0], U32);
|
||||
@ -104,56 +108,68 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// Returns the inputs with unique counter values for ThreeFry2x32.
|
||||
ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) {
|
||||
ThreeFry2x32State inputs;
|
||||
inputs[0] = Iota(builder, U32, size);
|
||||
inputs[1] = inputs[0] + ConstantR0<uint32>(builder, size);
|
||||
return inputs;
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
|
||||
XlaBuilder* builder = key[0].builder();
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
ThreeFry2x32State inputs = GetInputs(half_size, builder);
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
auto result = ConcatInDim(builder, outputs, 0);
|
||||
return Reshape(result, AsInt64Slice(shape.dimensions()));
|
||||
}
|
||||
|
||||
// Converts a uint64 to two uint32s.
|
||||
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
|
||||
auto builder = u64.builder();
|
||||
auto const32 = ConstantR0WithType(builder, U64, 32);
|
||||
auto fst = ConvertElementType(u64, U32);
|
||||
auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
|
||||
XlaBuilder* builder = u64.builder();
|
||||
XlaOp const32 = ConstantR0WithType(builder, U64, 32);
|
||||
XlaOp fst = ConvertElementType(u64, U32);
|
||||
XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
|
||||
return {fst, snd};
|
||||
}
|
||||
|
||||
// Converts two uint32s to a uint64.
|
||||
XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
|
||||
auto builder = u32s[0].builder();
|
||||
XlaBuilder* builder = u32s[0].builder();
|
||||
return ConvertElementType(u32s[0], U64) |
|
||||
ShiftLeft(ConvertElementType(u32s[1], U64),
|
||||
ConstantR0WithType(builder, U64, 32));
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
|
||||
XlaBuilder* builder = key[0].builder();
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
ThreeFry2x32State inputs = GetInputs(size, builder);
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
|
||||
// low 32 bit: outputs[0], high 32 bit: outputs[1]
|
||||
auto result = Uint32sToUint64(outputs);
|
||||
return Reshape(result, AsInt64Slice(shape.dimensions()));
|
||||
// Given the initial state and the request number of random numbers to be
|
||||
// generated, returns the input for the random number generator and a new state.
|
||||
std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
|
||||
XlaOp initial_state, const int64 size) {
|
||||
XlaBuilder* builder = initial_state.builder();
|
||||
XlaOp input_u64 = Iota(builder, U64, size);
|
||||
input_u64 = input_u64 + initial_state;
|
||||
XlaOp new_state = initial_state + ConstantR0<uint64>(builder, size);
|
||||
return std::make_pair(Uint64ToUint32s(input_u64), new_state);
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
||||
XlaBuilder* builder = bits.builder();
|
||||
// Generates random 32bits with the given shape using the Three Fry
|
||||
// implementation. Returns the random bits and the new state.
|
||||
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
|
||||
XlaBuilder* builder = key.builder();
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||
GetThreeFryInputsAndUpdatedState(initial_state, half_size);
|
||||
ThreeFry2x32State inputs = inputs_state.first;
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
XlaOp result = ConcatInDim(builder, outputs, 0);
|
||||
return {Reshape(result, AsInt64Slice(shape.dimensions())),
|
||||
inputs_state.second};
|
||||
}
|
||||
|
||||
// Generates random 64bits with the given shape using the Three Fry
|
||||
// implementation. Returns the random bits and the new state.
|
||||
RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||
GetThreeFryInputsAndUpdatedState(initial_state, size);
|
||||
ThreeFry2x32State inputs = inputs_state.first;
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
XlaOp result = Uint32sToUint64(outputs);
|
||||
return {Reshape(result, AsInt64Slice(shape.dimensions())),
|
||||
inputs_state.second};
|
||||
}
|
||||
|
||||
XlaOp ConvertRandomBitsToUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
  XlaBuilder* builder = bits.builder();
  // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
  // forces the random bits into the mantissa.
  constexpr int kFloatBits = 32;
@ -161,50 +177,95 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
  bits = ShiftRightLogical(
             bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
         ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
  auto floats = BitcastConvertType(bits, F32);
  XlaOp values = BitcastConvertType(bits, F32);

  // We have a floating point number in the range [1.0, 2.0).
  // Subtract 1.0f to shift to the range [0.0, 1.0)
  floats = floats - ConstantR0<float>(builder, 1.0f);
  values = values - ConstantR0<float>(builder, 1.0f);
  // Multiply and add to shift to the range [minval, maxval).
  return floats * (maxval - minval) + minval;
  return values * (maxval - minval) + minval;
}

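A standalone Python sketch of the mantissa trick above (assuming `kMantissaBits` is 23, as the comment states; the helper name here is mine): the top 23 random bits are OR-ed into the mantissa of 1.0f, giving a float uniform in [1.0, 2.0), and subtracting 1.0 lands it in [0.0, 1.0).

```python
# Sketch of the bits-to-float conversion above (illustration only).
import struct

def bits_to_unit_float(random_u32: int) -> float:
    one_bits = struct.unpack('<I', struct.pack('<f', 1.0))[0]  # 0x3F800000
    mantissa = random_u32 >> (32 - 23)          # keep 23 random mantissa bits
    f = struct.unpack('<f', struct.pack('<I', one_bits | mantissa))[0]
    return f - 1.0                              # [1.0, 2.0) -> [0.0, 1.0)

print(bits_to_unit_float(0xFFFFFFFF))  # just below 1.0
print(bits_to_unit_float(0x00000000))  # exactly 0.0
```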
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
                             PrimitiveType type, PrimitiveType unsigned_type) {
XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
                                    PrimitiveType type,
                                    PrimitiveType unsigned_type) {
  XlaBuilder* builder = bits.builder();
  auto range = BitcastConvertType(maxval, unsigned_type) -
               BitcastConvertType(minval, unsigned_type);
  auto dist = Rem(bits, range);
  auto dist_div_2 =
  XlaOp range = BitcastConvertType(maxval, unsigned_type) -
                BitcastConvertType(minval, unsigned_type);
  XlaOp dist = Rem(bits, range);
  XlaOp dist_div_2 =
      ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));

  return minval + BitcastConvertType(dist_div_2, type) +
         BitcastConvertType(dist - dist_div_2, type);
}

XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
                          XlaOp minval, XlaOp maxval) {
  XlaBuilder* builder = seeds[0].builder();
XlaOp UniformToNormalUsingSqrtErfInv(XlaOp uniform) {
  return ScalarLike(uniform, std::sqrt(2.0)) * ErfInv(uniform);
}

} // namespace

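The two numeric identities behind `ConvertRandomBitsToUniformInt` and `UniformToNormalUsingSqrtErfInv`, written out. This is my reading of the code, stated as math rather than as a definitive spec: the integer case adds the offset in two halves, presumably so that each partial sum stays inside [minval, maxval) and no signed overflow occurs; the normal case is the usual inverse-CDF transform.

```latex
% Integer case: with d = bits mod (maxval - minval) computed in the unsigned type,
\[
\mathrm{minval} + \lfloor d/2 \rfloor + \bigl(d - \lfloor d/2 \rfloor\bigr)
  = \mathrm{minval} + d \in [\mathrm{minval}, \mathrm{maxval}),
\]
% and each of the two partial sums also lies in [minval, maxval).
%
% Normal case: if U is uniform on (-1, 1), then since
% \Phi^{-1}(p) = \sqrt{2}\,\operatorname{erf}^{-1}(2p - 1) and U plays the role
% of 2p - 1,
\[
Z = \sqrt{2}\,\operatorname{erf}^{-1}(U) \sim \mathcal{N}(0, 1).
\]
```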
||||
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
const Shape& shape) {
|
||||
PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case F32: {
|
||||
auto bits = StatelessRngUniformU32(seeds, shape);
|
||||
return StatelessRngUniformF32(bits, minval, maxval);
|
||||
}
|
||||
case S32: {
|
||||
auto bits = StatelessRngUniformU32(seeds, shape);
|
||||
return StatelessRngUniformInt(bits, minval, maxval, type, U32);
|
||||
}
|
||||
case S64: {
|
||||
auto bits = StatelessRngUniformU64(seeds, shape);
|
||||
return StatelessRngUniformInt(bits, minval, maxval, type, U64);
|
||||
}
|
||||
case F32:
|
||||
case U32:
|
||||
case S32:
|
||||
return ThreeFryRngBit32(key, initial_state, shape);
|
||||
case U64:
|
||||
case S64:
|
||||
return ThreeFryRngBit64(key, initial_state, shape);
|
||||
default:
|
||||
return builder->ReportError(Unimplemented(
|
||||
"Types other than F32, S32 and S64 are not implemented by "
|
||||
"StatelessRngUniform."));
|
||||
return {key.builder()->ReportError(Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by ThreeFryBitGenerator; got %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
initial_state};
|
||||
}
|
||||
}
|
||||
|
||||
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const Shape& shape) {
|
||||
DCHECK_EQ(shape.element_type(), F32);
|
||||
RngOutput bits_state = bit_generator(key, initial_state, shape);
|
||||
XlaOp bits = bits_state.value;
|
||||
XlaOp new_state = bits_state.state;
|
||||
return {ConvertRandomBitsToUniformF32(bits, minval, maxval), new_state};
|
||||
}
|
||||
|
||||
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const Shape& shape) {
|
||||
RngOutput bits_state = bit_generator(key, initial_state, shape);
|
||||
XlaOp bits = bits_state.value;
|
||||
XlaOp new_state = bits_state.state;
|
||||
PrimitiveType type = shape.element_type();
|
||||
PrimitiveType unsigned_type;
|
||||
if (type == U32 || type == S32) {
|
||||
unsigned_type = U32;
|
||||
} else {
|
||||
DCHECK(type == U64 || type == S64);
|
||||
unsigned_type = U64;
|
||||
}
|
||||
return {
|
||||
ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
|
||||
new_state};
|
||||
}
|
||||
|
||||
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator,
|
||||
const Shape& shape) {
|
||||
DCHECK_EQ(shape.element_type(), F32);
|
||||
XlaBuilder* builder = key.builder();
|
||||
RngOutput bits_state = UniformF32Distribution(
|
||||
key, initial_state, bit_generator,
|
||||
ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
||||
ConstantR0<float>(builder, 1.0), shape);
|
||||
XlaOp normal = UniformToNormalUsingSqrtErfInv(bits_state.value);
|
||||
return {normal, bits_state.state};
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -23,37 +23,52 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Records the bits and state generated by a random number generator.
|
||||
struct RngOutput {
|
||||
XlaOp value;
|
||||
XlaOp state;
|
||||
};
|
||||
|
||||
// A BitGenerator returns random bits and updated random bit generator state.
|
||||
//
|
||||
// key: is a value input to a random number generator that can affect the
|
||||
// sequence of number it will generate. A random number generator constructs
|
||||
// its seed using the key and the initial state. The tf2xla bridge passes the
|
||||
// seed operand of a tensorflow random operation as a key to the random bit
|
||||
// generator, for example.
|
||||
// initial_state: initial_state is the initial state of the current random
|
||||
// number generation. It could be 0 for a stateless random operation, and
|
||||
// the returned state from a previous execution for a stateful random
|
||||
// operation.
|
||||
// shape: the shape of the random bits.
|
||||
using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state,
|
||||
const xla::Shape& shape)>;
|
||||
|
||||
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key);
|
||||
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
const xla::Shape& shape);
|
||||
|
||||
// Returns a tensor containing 'shape' random values uniformly distributed in
|
||||
// the range [minval, maxval). Requires 2 32-bit integer seeds.
|
||||
// Currently only 'shape's of type F32, S32 and S64 are implemented.
|
||||
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
|
||||
XlaOp minval, XlaOp maxval);
|
||||
// Uses the given bit generator to generate random bits and then converts the
|
||||
// random bits to random numbers of uniform distribution in the given range.
|
||||
// Returns the random numbers and the state of the random number generator.
|
||||
// This function is for shape with float element type.
|
||||
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const xla::Shape& shape);
|
||||
|
||||
// Converts a 32-bit (signed or unsigned) integer random number `bits` into a
|
||||
// float32 in the range [minval, maxval).
|
||||
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
|
||||
// Similar to UniformF32Distribution but for shape with integer element types.
|
||||
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const xla::Shape& shape);
|
||||
|
||||
// Converts an integer random number 'bits' of type 'type' to a random number
|
||||
// in the range [minval, maxval), of the same type. 'unsigned_type' is the
|
||||
// unsigned version of 'type' (could be the same) with the same bit width.
|
||||
// The algorithm is the same one that TF uses right now, but it's
|
||||
// uniform only when maxval - minval is a divisor of the range that bits is
|
||||
// generated from.
|
||||
// TODO(b/72573764): Generate real uniform integer distribution.
|
||||
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
PrimitiveType type, PrimitiveType unsigned_type);
|
||||
|
||||
// The following 2 functions, for converting between one uint64 and two uint32s,
|
||||
// use the contract "lower 32 bits for the first uint32, higher 32 bits for the
|
||||
// second".
|
||||
ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
|
||||
XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
|
||||
// Uses the given bit generator to generate random bits and then converts the
|
||||
// random bits to random numbers of normal distribution.
|
||||
// Returns the random numbers and the state of the random number generator.
|
||||
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator,
|
||||
const xla::Shape& shape);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
|
@ -69,6 +69,11 @@ class Tile {
|
||||
// combined with the next minor dimension before tiling is applied.
|
||||
static constexpr int64 kCombineDimension = std::numeric_limits<int64>::min();
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const Tile& t) {
|
||||
return H::combine(std::move(h), t.dimensions_);
|
||||
}
|
||||
|
||||
private:
|
||||
// The bounds of the tile.
|
||||
std::vector<int64> dimensions_;
|
||||
@ -212,6 +217,13 @@ class Layout {
|
||||
element_size_in_bits_ = 0;
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const Layout& l) {
|
||||
return H::combine(std::move(h), l.format_, l.minor_to_major_,
|
||||
l.max_sparse_elements_, l.tiles_,
|
||||
l.element_size_in_bits_);
|
||||
}
|
||||
|
||||
private:
|
||||
// The format of this layout.
|
||||
Format format_ = INVALID_FORMAT;
|
||||
|
@ -109,6 +109,7 @@ tf_pybind_extension(
|
||||
":xrt",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@pybind11",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
@ -81,9 +82,12 @@ StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
|
||||
|
||||
DeviceMemoryAllocator* allocator = client->backend().memory_allocator();
|
||||
TransferManager* transfer_manager = client->backend().transfer_manager();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape shape, transfer_manager->ChooseCompactLayoutForShape(tree.shape));
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
|
||||
transfer_manager->AllocateScopedShapedBuffer(
|
||||
tree.shape, allocator, device_ordinal));
|
||||
shape, allocator, device_ordinal));
|
||||
TF_ASSIGN_OR_RETURN(auto stream,
|
||||
client->mutable_backend()->BorrowStream(device_ordinal));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -91,7 +95,7 @@ StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
|
||||
|
||||
auto it = tree.leaves.begin();
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(tree.shape)) {
|
||||
ShapeUtil::GetLeafShapes(shape)) {
|
||||
TF_RET_CHECK(it != tree.leaves.end());
|
||||
ShapedBuffer leaf(
|
||||
indexed_shape.shape,
|
||||
@ -224,10 +228,7 @@ StatusOr<LocalShapedBuffer> LocalExecutableWrapper::Execute(
|
||||
result_buffer_status = executable_->Run(argument_buffers, options);
|
||||
|
||||
if (!result_buffer_status.ok()) {
|
||||
return InternalError(
|
||||
"Failed running replica 0 (other replicas may have failed as well): "
|
||||
"%s.",
|
||||
result_buffer_status.status().ToString());
|
||||
return result_buffer_status.status();
|
||||
}
|
||||
return LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
|
||||
client_);
|
||||
@ -298,10 +299,12 @@ LocalExecutableWrapper::ExecutePerReplica(
|
||||
for (int replica = 0; replica < num_replicas(); ++replica) {
|
||||
auto& statusor = results[replica];
|
||||
if (!statusor.ok()) {
|
||||
return InternalError(
|
||||
"Failed running replica %d (other replicas may have failed as well): "
|
||||
"%s.",
|
||||
replica, statusor.status().ToString());
|
||||
return AppendStatus(
|
||||
statusor.status(),
|
||||
absl::StrFormat(
|
||||
"while running replica %d of a replicated computation (other "
|
||||
"replicas may have failed as well).",
|
||||
replica));
|
||||
}
|
||||
wrapped_results[replica] =
|
||||
LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
|
||||
@ -346,23 +349,53 @@ StatusOr<std::string> GetComputationHloDotGraph(
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<LocalExecutableWrapper>>
|
||||
LocalExecutableWrapper::Compile(const XlaComputation& computation,
|
||||
const std::vector<Shape>& argument_shapes,
|
||||
std::vector<Shape> argument_layouts,
|
||||
const ExecutableBuildOptions* build_options,
|
||||
LocalClient* client) {
|
||||
tensorflow::profiler::TraceMe("LocalExecutable::Compile");
|
||||
std::vector<const Shape*> argument_shape_pointers;
|
||||
argument_shape_pointers.reserve(argument_shapes.size());
|
||||
for (auto& argument_shape : argument_shapes) {
|
||||
argument_shape_pointers.push_back(&argument_shape);
|
||||
std::vector<const Shape*> argument_layout_pointers;
|
||||
argument_layout_pointers.reserve(argument_layouts.size());
|
||||
|
||||
// Assign a default layout to any array subshapes that are missing layouts.
|
||||
auto assign_layouts = [client](Shape* shape) {
|
||||
return ShapeUtil::ForEachMutableSubshapeWithStatus(
|
||||
shape, [&](Shape* subshape, const ShapeIndex&) {
|
||||
if (subshape->IsArray() && !subshape->has_layout()) {
|
||||
LayoutUtil::SetToDefaultLayout(subshape);
|
||||
TF_ASSIGN_OR_RETURN(*subshape,
|
||||
client->backend()
|
||||
.transfer_manager()
|
||||
->ChooseCompactLayoutForShape(*subshape));
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
};
|
||||
|
||||
for (Shape& layout : argument_layouts) {
|
||||
argument_layout_pointers.push_back(&layout);
|
||||
assign_layouts(&layout);
|
||||
}
|
||||
|
||||
ExecutableBuildOptions options;
|
||||
if (build_options != nullptr) {
|
||||
options = *build_options;
|
||||
}
|
||||
|
||||
Shape result_layout;
|
||||
if (options.result_layout()) {
|
||||
result_layout = *options.result_layout();
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
|
||||
computation.GetProgramShape());
|
||||
result_layout = program_shape.result();
|
||||
LayoutUtil::ClearLayout(&result_layout);
|
||||
}
|
||||
assign_layouts(&result_layout);
|
||||
options.set_result_layout(result_layout);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto local_executable,
|
||||
client->Compile(computation, argument_shape_pointers, options));
|
||||
client->Compile(computation, argument_layout_pointers, options));
|
||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
||||
client->backend().computation_placer()->AssignDevices(
|
||||
options.num_replicas(), /*computation_count=*/1));
|
||||
|
@ -82,8 +82,7 @@ class LocalExecutableWrapper {
|
||||
public:
|
||||
// Compiles a computation to an executable.
|
||||
static StatusOr<std::unique_ptr<LocalExecutableWrapper>> Compile(
|
||||
const XlaComputation& computation,
|
||||
const std::vector<Shape>& argument_shapes,
|
||||
const XlaComputation& computation, std::vector<Shape> argument_layouts,
|
||||
const ExecutableBuildOptions* build_options, LocalClient* client);
|
||||
|
||||
LocalExecutableWrapper(std::unique_ptr<LocalExecutable> executable,
|
||||
|
@ -75,7 +75,9 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
if (layout) {
|
||||
return ShapeUtil::MakeShapeWithLayout(type, dims, *layout);
|
||||
} else {
|
||||
return ShapeUtil::MakeShape(type, dims);
|
||||
Shape shape = ShapeUtil::MakeShape(type, dims);
|
||||
shape.clear_layout();
|
||||
return shape;
|
||||
}
|
||||
},
|
||||
"Makes an array shape.", py::arg("type"), py::arg("dims"),
|
||||
@ -87,7 +89,9 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def("tuple_shapes",
|
||||
static_cast<const std::vector<Shape>& (Shape::*)() const>(
|
||||
&Shape::tuple_shapes))
|
||||
.def("__repr__", [](const Shape& shape) { return shape.ToString(); });
|
||||
.def("__repr__", [](const Shape& shape) {
|
||||
return shape.ToString(/*print_layouts=*/true);
|
||||
});
|
||||
|
||||
py::class_<ProgramShape>(m, "ProgramShape")
|
||||
.def(py::init(
|
||||
|
@ -73,8 +73,7 @@ class Backend(object):
|
||||
"""Destructures a tuple buffer into a sequence of buffers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def compile(self, computation, argument_shapes, result_shape,
|
||||
compile_options):
|
||||
def compile(self, computation, compile_options):
|
||||
"""Compiles a computation. Returns an executable."""
|
||||
|
||||
@abc.abstractmethod
|
||||
@ -115,15 +114,19 @@ class LocalBackend(Backend):
|
||||
def destructure_tuple(self, c_buffer):
|
||||
return c_buffer.DestructureTuple()
|
||||
|
||||
def compile(self, c_computation, argument_shapes, result_shape,
|
||||
compile_options):
|
||||
def compile(self, c_computation, compile_options):
|
||||
options = _xla.ExecutableBuildOptions()
|
||||
options.num_replicas = compile_options.num_replicas
|
||||
if compile_options.argument_layouts:
|
||||
argument_layouts = [
|
||||
s.as_xla_shape() for s in compile_options.argument_layouts
|
||||
]
|
||||
else:
|
||||
argument_layouts = c_computation.GetProgramShape().Parameters()
|
||||
if compile_options.result_layout:
|
||||
options.result_layout = compile_options.result_layout.as_xla_shape()
|
||||
argument_shapes = [s.as_xla_shape() for s in argument_shapes]
|
||||
return _xla.LocalExecutable.Compile(c_computation, argument_shapes, options,
|
||||
self.client)
|
||||
return _xla.LocalExecutable.Compile(c_computation, argument_layouts,
|
||||
options, self.client)
|
||||
|
||||
def delete_executable(self, executable):
|
||||
executable.Delete()
|
||||
@ -234,42 +237,6 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
|
||||
source_line=lineno)
|
||||
|
||||

class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2


def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                        window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints)."""
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    if padding_type.upper() == 'VALID':
      padding_type = PaddingType.VALID
    elif padding_type.upper() == 'SAME':
      padding_type = PaddingType.SAME
    else:
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))

  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)
  elif padding_type == PaddingType.SAME:
    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
    pad_sizes = [
        max((out_size - 1) * stride + filter_size - in_size, 0)
        for out_size, stride, filter_size, in_size in zip(
            out_shape, window_strides, rhs_dims, lhs_dims)
    ]
    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
  else:
    msg = 'Unexpected PaddingType value: {}'
    raise ValueError(msg.format(padding_type))

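This block is moved further down the module (after the `Executable` class) rather than changed. For reference, a worked example of the SAME-padding arithmetic it performs, as a standalone sketch that is not part of the module itself:

```python
# Worked example of the SAME-padding computation in the helper above.
import math

in_size, filter_size, stride = 10, 3, 2
out_size = math.ceil(in_size / stride)                         # 5
pad = max((out_size - 1) * stride + filter_size - in_size, 0)  # 4*2 + 3 - 10 = 1
lo, hi = pad // 2, pad - pad // 2                              # (0, 1)
assert (lo, hi) == (0, 1)
# i.e. _convert_padding_type_to_pad_values('SAME', [10], [3], [2]) == [(0, 1)]
```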
||||
PrimitiveType = _xla.PrimitiveType
|
||||
|
||||
XLA_ELEMENT_TYPE_TO_DTYPE = {
|
||||
@ -303,7 +270,7 @@ def dtype_to_etype(dtype):
|
||||
return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
|
||||
|
||||
|
||||
class LocalBuffer(object):
|
||||
class Buffer(object):
|
||||
"""Represents a handle to data owned by XLA.
|
||||
|
||||
The referent is ready for use in executing a local, compiled
|
||||
@ -322,7 +289,7 @@ class LocalBuffer(object):
|
||||
backend = backend or get_local_backend()
|
||||
pyval = require_numpy_array_layout(pyval)
|
||||
cbuf = backend.buffer_from_pyval(pyval, device)
|
||||
return LocalBuffer(cbuf, backend, device)
|
||||
return Buffer(cbuf, backend, device)
|
||||
|
||||
def to_py(self):
|
||||
return self.c_buffer.ToPython()
|
||||
@ -344,7 +311,7 @@ class LocalBuffer(object):
|
||||
result = self._backend.destructure_tuple(self.c_buffer)
|
||||
self.delete()
|
||||
return tuple(
|
||||
LocalBuffer(sub_buffer, device=self._device, backend=self._backend)
|
||||
Buffer(sub_buffer, device=self._device, backend=self._backend)
|
||||
for sub_buffer in result)
|
||||
|
||||
def is_deleted(self):
|
||||
@ -354,6 +321,11 @@ class LocalBuffer(object):
|
||||
self.delete()
|
||||
|
||||
|
||||
# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
|
||||
# compatibility with Jaxlib versions older than 0.1.13.
|
||||
LocalBuffer = Buffer
|
||||
|
||||
|
||||
class Format(enum.IntEnum):
|
||||
"""Python copy of the Format protocol buffer enum."""
|
||||
INVALID_FORMAT = 0
|
||||
@ -396,6 +368,7 @@ class Shape(object):
|
||||
@staticmethod
|
||||
def from_pyval(pyval):
|
||||
"""Returns a Shape that describes a tuple-tree of Numpy arrays."""
|
||||
|
||||
def convert(pyval):
|
||||
if isinstance(pyval, tuple):
|
||||
return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
|
||||
@ -561,24 +534,6 @@ def require_numpy_array_layout(value):
|
||||
return np.require(value, requirements=['C', 'A'])
|
||||
|
||||
|
||||
class CompileOptions(object):
|
||||
"""Python object for XLA compile options.
|
||||
|
||||
These options can be passed to the 'compile' step when using a local XLA
|
||||
client.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.xla_dump_to = None
|
||||
self.dump_hlo_pass_re = None
|
||||
self.dump_hlo_module_re = None
|
||||
self.dump_hlo_as_text = None
|
||||
self.dump_hlo_as_proto = None
|
||||
self.hlo_profile = None
|
||||
self.num_replicas = 1
|
||||
self.result_layout = None
|
||||
|
||||
|
||||
def transfer_to_infeed(value, device_ordinal=0):
|
||||
"""Transfers the given value into the XLA infeed queue.
|
||||
|
||||
@ -611,8 +566,28 @@ def transfer_from_outfeed(shape, device_ordinal=0):
|
||||
"""
|
||||
# TODO(phawkins): support non-default backends.
|
||||
backend = get_local_backend()
|
||||
return backend.client.TransferFromOutfeed(shape.as_xla_shape(),
|
||||
device_ordinal)
|
||||
return backend.client.TransferFromOutfeed(
|
||||
shape.with_major_to_minor_layout_if_absent().as_xla_shape(),
|
||||
device_ordinal)
|
||||
|
||||
|
||||
class CompileOptions(object):
|
||||
"""Python object for XLA compile options.
|
||||
|
||||
These options can be passed to the 'compile' step when using a local XLA
|
||||
client.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.xla_dump_to = None
|
||||
self.dump_hlo_pass_re = None
|
||||
self.dump_hlo_module_re = None
|
||||
self.dump_hlo_as_text = None
|
||||
self.dump_hlo_as_proto = None
|
||||
self.hlo_profile = None
|
||||
self.num_replicas = 1
|
||||
self.argument_layouts = None
|
||||
self.result_layout = None
|
||||
|
||||
|
||||
class Computation(object):
@ -656,55 +631,28 @@ class Computation(object):
    """
    return self.computation.GetHloDotGraph()

  def Compile(self,
              argument_shapes=(),
              compile_options=None,
              layout_fn=None,
              backend=None):
  def Compile(self, argument_shapes=None, compile_options=None, backend=None):
    """Compiles a computation.

    Computations are the result of a "ComputationBuild'ing" process.

    Arguments:
      argument_shapes: parameter shapes -- they are first laid out by layout_fn
        if layout_fn is provided. Otherwise, the default layout for those shapes
        will be used.
      argument_shapes: Deprecated. Use compile_options.argument_layouts instead.
      compile_options: options to use for compilation, includes an optional laid
        out result shape for the computation.
      layout_fn: lambda that is used to lay out the argument/result shapes.
      backend: a `Backend` for which an executable should be generated.

    Returns:
      An Executable instance.
    """
    backend = backend or self._backend or get_local_backend()
    result_shape = _wrap_shape(self.computation.GetProgramShape().Result())

    if layout_fn:
      argument_shapes = [
          shape.map_leaves(layout_fn) for shape in argument_shapes
      ]
      result_shape = result_shape.map_leaves(layout_fn)

    argument_shapes = list(argument_shapes)

    compile_options = compile_options or CompileOptions()
    compile_options.result_layout = result_shape
    c = backend.compile(self.computation, argument_shapes, result_shape,
                        compile_options)
    if argument_shapes:
      compile_options.argument_layouts = argument_shapes
    c = backend.compile(self.computation, compile_options)
    return Executable(c, backend=backend)

  def CompileWithExampleArguments(self,
                                  arguments=(),
                                  compile_options=None,
                                  layout_fn=None,
                                  backend=None):
    return self.Compile(
        argument_shapes=[Shape.from_pyval(arg) for arg in arguments],
        compile_options=compile_options,
        layout_fn=layout_fn,
        backend=backend)

  def GetProgramShape(self):
    return _wrap_program_shape(self._c_computation.GetProgramShape())

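The deprecation note above amounts to the following migration for callers that previously relied on example arguments. This is a sketch, not part of the patch; it assumes `computation` is an already-built Computation and reuses only calls that appear elsewhere in this diff.

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    args = [np.array([1.0, 2.0], dtype=np.float32)]

    # Before: layouts were inferred from example arguments at compile time.
    #   executable = computation.CompileWithExampleArguments(args)

    # After: the per-argument layouts are carried by CompileOptions.
    options = xla_client.CompileOptions()
    options.argument_layouts = [xla_client.Shape.from_pyval(arg) for arg in args]
    executable = computation.Compile(compile_options=options)
    result = executable.ExecuteWithPythonValues(args)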
@ -725,25 +673,25 @@ class Executable(object):
    return self._device_ordinals

  def Execute(self, arguments=(), check_for_deleted_args=True):
    """Execute on one replica with LocalBuffer arguments and return value."""
    """Execute on one replica with Buffer arguments and return value."""
    if check_for_deleted_args and any(arg.is_deleted() for arg in arguments):
      raise ValueError('Executing with deleted local buffer argument')
    raw_args = [arg.c_buffer for arg in arguments]
    output_buffer = self._backend.execute(self._c_executable, raw_args)
    return LocalBuffer(
    return Buffer(
        output_buffer, backend=self._backend, device=self._device_ordinals[0])

  def ExecutePerReplica(self, arguments=None):
    """Execute on many replicas with LocalBuffer arguments and return value.
    """Execute on many replicas with Buffer arguments and return value.

    Args:
      arguments: A sequence of sequences of LocalBuffers. The i'th inner
        sequence comprises the arguments for execution on the i'th replica.
      arguments: A sequence of sequences of Buffers. The i'th inner sequence
        comprises the arguments for execution on the i'th replica.

    Returns:
      A list of the computation's outputs for each replica, as a LocalBuffer. If
      A list of the computation's outputs for each replica, as a Buffer. If
      a shallow sequence of arguments was passed in for `arguments`, then the
      sole, zero'th replica's output is returned instead, as a LocalBuffer.
      sole, zero'th replica's output is returned instead, as a Buffer.
    """
    if arguments is None:
      arguments = ((),) * len(self._device_ordinals)
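A sketch of the nested-arguments convention described in the ExecutePerReplica docstring above. The device ordinals, input values, and the module-level Buffer name are assumptions for illustration only.

    from tensorflow.compiler.xla.python import xla_client

    def run_on_two_replicas(executable, x0, x1, devices=(0, 1)):
      """Feeds replica 0 with x0 and replica 1 with x1; returns both outputs."""
      per_replica_args = [
          [xla_client.Buffer.from_pyval(x0, device=devices[0])],
          [xla_client.Buffer.from_pyval(x1, device=devices[1])],
      ]
      outputs = executable.ExecutePerReplica(per_replica_args)
      return [out.to_py() for out in outputs]  # one result per replica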
@ -770,9 +718,9 @@ class Executable(object):
|
||||
output_buffers = self._backend.execute_replicated(self._c_executable,
|
||||
stripped_args)
|
||||
|
||||
# Wrap output handles in LocalBuffer instances
|
||||
# Wrap output handles in Buffer instances
|
||||
return tuple(
|
||||
LocalBuffer(
|
||||
Buffer(
|
||||
output_buffer,
|
||||
backend=self._backend,
|
||||
device=self._device_ordinals[replica])
|
||||
@ -782,7 +730,7 @@ class Executable(object):
|
||||
"""Execute on one replica with Python values as arguments and output."""
|
||||
|
||||
def put(arg):
|
||||
return LocalBuffer.from_pyval(
|
||||
return Buffer.from_pyval(
|
||||
arg, device=self._device_ordinals[0], backend=self._backend)
|
||||
|
||||
arguments = [put(arg) for arg in arguments]
|
||||
@ -792,7 +740,7 @@ class Executable(object):
|
||||
"""Execute on many replicas with Python values as arguments and output."""
|
||||
|
||||
def put(arg, device):
|
||||
return LocalBuffer.from_pyval(arg, device, backend=self._backend)
|
||||
return Buffer.from_pyval(arg, device, backend=self._backend)
|
||||
|
||||
# pylint: disable=g-complex-comprehension
|
||||
arguments = [[
|
||||
@ -804,6 +752,42 @@ class Executable(object):
|
||||
self._backend.delete_executable(self._c_executable)
|
||||
|
||||
|
||||
class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2


def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                        window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints)."""
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    if padding_type.upper() == 'VALID':
      padding_type = PaddingType.VALID
    elif padding_type.upper() == 'SAME':
      padding_type = PaddingType.SAME
    else:
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))

  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)
  elif padding_type == PaddingType.SAME:
    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
    pad_sizes = [
        max((out_size - 1) * stride + filter_size - in_size, 0)
        for out_size, stride, filter_size, in_size in zip(
            out_shape, window_strides, rhs_dims, lhs_dims)
    ]
    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
  else:
    msg = 'Unexpected PaddingType value: {}'
    raise ValueError(msg.format(padding_type))


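To make the SAME-padding arithmetic above concrete, here is a small worked example. The shapes are assumptions chosen to match the 1x2-filter convolution tests below, and it calls the private helper directly purely for illustration.

    # 3x4 input, 1x2 filter, unit strides: SAME must produce a 3x4 output
    # (ceil(3/1) x ceil(4/1)), so only the second dimension needs one pad column.
    pads = _convert_padding_type_to_pad_values(
        padding_type='SAME',
        lhs_dims=(3, 4),        # input spatial dimensions
        rhs_dims=(1, 2),        # filter spatial dimensions
        window_strides=(1, 1))
    assert [tuple(map(int, p)) for p in pads] == [(0, 0), (0, 1)]

    # VALID never pads, whatever the shapes are.
    assert _convert_padding_type_to_pad_values(
        'VALID', (3, 4), (1, 2), (1, 1)) == [(0, 0), (0, 0)]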
class ComputationBuilder(object):
|
||||
"""XLA computation builder.
|
||||
|
||||
@ -858,7 +842,9 @@ class ComputationBuilder(object):
|
||||
Returns:
|
||||
An XlaOp.
|
||||
"""
|
||||
return ops.Infeed(self._builder, shape.as_xla_shape())
|
||||
return ops.Infeed(
|
||||
self._builder,
|
||||
shape.with_major_to_minor_layout_if_absent().as_xla_shape())
|
||||
|
||||
def Outfeed(self, operand):
|
||||
"""Enqueues an outfeed op onto the computation.
|
||||
@ -955,8 +941,10 @@ class ComputationBuilder(object):
|
||||
if parameter_num is None:
|
||||
parameter_num = next(self._parameter_numbering)
|
||||
|
||||
return ops.Parameter(self._builder, parameter_num, shape.as_xla_shape(),
|
||||
name.encode('utf8'))
|
||||
return ops.Parameter(
|
||||
self._builder, parameter_num,
|
||||
shape.with_major_to_minor_layout_if_absent().as_xla_shape(),
|
||||
name.encode('utf8'))
|
||||
|
||||
def ParameterFromNumpy(self, value, name=None, parameter_num=None):
|
||||
"""Enqueues a Parameter op onto the computation.
|
||||
@ -1143,9 +1131,11 @@ class ComputationBuilder(object):
|
||||
pads = _convert_padding_type_to_pad_values(
|
||||
padding,
|
||||
self.GetShape(operand).dimensions(), window_dimensions, window_strides)
|
||||
return ops.SelectAndScatterWithGeneralPadding(
|
||||
operand, select.computation, window_dimensions, window_strides, pads,
|
||||
source, init_value, scatter.computation)
|
||||
return ops.SelectAndScatterWithGeneralPadding(operand, select.computation,
|
||||
window_dimensions,
|
||||
window_strides, pads, source,
|
||||
init_value,
|
||||
scatter.computation)
|
||||
|
||||
def Slice(self, operand, start_indices, limit_indices, strides=None):
|
||||
"""Enqueues a slice operation onto the computation.
|
||||
@ -1306,13 +1296,15 @@ class ComputationBuilder(object):
|
||||
pads = _convert_padding_type_to_pad_values(
|
||||
padding,
|
||||
self.GetShape(operand).dimensions(), window_dimensions, window_strides)
|
||||
return ops.ReduceWindowWithGeneralPadding(
|
||||
operand, init_value, computation_to_apply.computation,
|
||||
window_dimensions, window_strides, (), (), pads)
|
||||
return ops.ReduceWindowWithGeneralPadding(operand, init_value,
|
||||
computation_to_apply.computation,
|
||||
window_dimensions, window_strides,
|
||||
(), (), pads)
|
||||
|
||||
def ReduceWindowWithGeneralPadding(
|
||||
self, operand, init_value, computation_to_apply, window_dimensions,
|
||||
window_strides, base_dilations, window_dilations, padding):
|
||||
def ReduceWindowWithGeneralPadding(self, operand, init_value,
|
||||
computation_to_apply, window_dimensions,
|
||||
window_strides, base_dilations,
|
||||
window_dilations, padding):
|
||||
"""Enqueues a windowed reduction operation onto the computation.
|
||||
|
||||
Args:
|
||||
@ -1328,10 +1320,11 @@ class ComputationBuilder(object):
|
||||
Returns:
|
||||
An XlaOp representing the added ReduceWindow op.
|
||||
"""
|
||||
return ops.ReduceWindowWithGeneralPadding(
|
||||
operand, init_value, computation_to_apply.computation,
|
||||
window_dimensions, window_strides, base_dilations, window_dilations,
|
||||
padding)
|
||||
return ops.ReduceWindowWithGeneralPadding(operand, init_value,
|
||||
computation_to_apply.computation,
|
||||
window_dimensions, window_strides,
|
||||
base_dilations, window_dilations,
|
||||
padding)
|
||||
|
||||
def RngNormal(self, mu, sigma, dims):
|
||||
"""Enqueues an RngNormal operation onto the computation.
|
||||
@ -1709,6 +1702,7 @@ def _forward_methods_to_local_builder():
|
||||
forward.__name__ = method_name
|
||||
setattr(ComputationBuilder, method_name, forward)
|
||||
|
||||
|
||||
_forward_methods_to_local_builder()
|
||||
|
||||
|
||||
|
@ -38,7 +38,7 @@ class ComputationTest(unittest.TestCase):
|
||||
return xla_client.ComputationBuilder(name)
|
||||
|
||||
def _Execute(self, c, arguments):
|
||||
compiled_c = c.Build().CompileWithExampleArguments(arguments)
|
||||
compiled_c = c.Build().Compile()
|
||||
return compiled_c.ExecuteWithPythonValues(arguments)
|
||||
|
||||
def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
|
||||
@ -53,11 +53,15 @@ class ComputationTest(unittest.TestCase):
|
||||
def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
|
||||
self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)
|
||||
|
||||
def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7,
|
||||
def _ExecuteAndCompareClose(self,
|
||||
c,
|
||||
arguments=(),
|
||||
expected=None,
|
||||
rtol=1e-7,
|
||||
atol=0):
|
||||
self._ExecuteAndAssertWith(
|
||||
functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
|
||||
c, arguments, expected)
|
||||
functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), c,
|
||||
arguments, expected)
|
||||
|
||||
|
||||
def NumpyArrayF32(*args, **kwargs):
|
||||
@ -212,20 +216,19 @@ class ComputationsWithConstantsTest(ComputationTest):
|
||||
|
||||
def testShiftLeft(self):
|
||||
c = self._NewComputation()
|
||||
c.ShiftLeft(c.Constant(NumpyArrayS32([3])),
|
||||
c.Constant(NumpyArrayS32([2])))
|
||||
c.ShiftLeft(c.Constant(NumpyArrayS32([3])), c.Constant(NumpyArrayS32([2])))
|
||||
self._ExecuteAndCompareClose(c, expected=[12])
|
||||
|
||||
def testShiftRightArithmetic(self):
|
||||
c = self._NewComputation()
|
||||
c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])),
|
||||
c.Constant(NumpyArrayS32([1])))
|
||||
c.ShiftRightArithmetic(
|
||||
c.Constant(NumpyArrayS32([-2])), c.Constant(NumpyArrayS32([1])))
|
||||
self._ExecuteAndCompareClose(c, expected=[-1])
|
||||
|
||||
def testShiftRightLogical(self):
|
||||
c = self._NewComputation()
|
||||
c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])),
|
||||
c.Constant(NumpyArrayS32([1])))
|
||||
c.ShiftRightLogical(
|
||||
c.Constant(NumpyArrayS32([-1])), c.Constant(NumpyArrayS32([1])))
|
||||
self._ExecuteAndCompareClose(c, expected=[2**31 - 1])
|
||||
|
||||
def testSum2DF64(self):
|
||||
@ -396,7 +399,7 @@ class LocalBufferTest(ComputationTest):
|
||||
"""Tests focusing on execution with LocalBuffers."""
|
||||
|
||||
def _Execute(self, c, arguments):
|
||||
compiled_c = c.Build().CompileWithExampleArguments(arguments)
|
||||
compiled_c = c.Build().Compile()
|
||||
arg_buffers = [xla_client.LocalBuffer.from_pyval(arg) for arg in arguments]
|
||||
result_buffer = compiled_c.Execute(arg_buffers)
|
||||
return result_buffer.to_py()
|
||||
@ -410,24 +413,22 @@ class LocalBufferTest(ComputationTest):
|
||||
c = self._NewComputation()
|
||||
c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14))
|
||||
self._ExecuteAndCompareClose(
|
||||
c,
|
||||
arguments=[NumpyArrayF32(1.11)],
|
||||
expected=4.25)
|
||||
c, arguments=[NumpyArrayF32(1.11)], expected=4.25)
|
||||
|
||||
def testTwoParameterSum(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)),
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0.)))
|
||||
c.Add(
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0.)),
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0.)))
|
||||
self._ExecuteAndCompareClose(
|
||||
c,
|
||||
arguments=[NumpyArrayF32(1.11), NumpyArrayF32(3.14)],
|
||||
expected=4.25)
|
||||
c, arguments=[NumpyArrayF32(1.11),
|
||||
NumpyArrayF32(3.14)], expected=4.25)
|
||||
|
||||
def testCannotCallWithDeletedBuffers(self):
|
||||
c = self._NewComputation()
|
||||
c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14))
|
||||
arg = NumpyArrayF32(1.11)
|
||||
compiled_c = c.Build().CompileWithExampleArguments([arg])
|
||||
compiled_c = c.Build().Compile()
|
||||
arg_buffer = xla_client.LocalBuffer.from_pyval(arg)
|
||||
arg_buffer.delete()
|
||||
with self.assertRaises(ValueError):
|
||||
@ -452,8 +453,8 @@ class LocalBufferTest(ComputationTest):
|
||||
np.testing.assert_equal(want, got)
|
||||
|
||||
def testDestructureTupleTwoArrayElementDifferentType(self):
|
||||
t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
|
||||
np.array([2, 3, 4, 5], dtype=np.int32))
|
||||
t = (np.array([1.0, 2.0, 3.0, 4.0],
|
||||
dtype=np.float32), np.array([2, 3, 4, 5], dtype=np.int32))
|
||||
local_buffer = xla_client.LocalBuffer.from_pyval(t)
|
||||
pieces = local_buffer.destructure()
|
||||
self.assertTrue(local_buffer.is_deleted())
|
||||
@ -486,7 +487,7 @@ class LocalBufferTest(ComputationTest):
|
||||
pyval = np.array([[1., 2.]], np.float32)
|
||||
local_buffer = xla_client.LocalBuffer.from_pyval(pyval)
|
||||
xla_shape = local_buffer.shape()
|
||||
self.assertEqual(xla_shape.dimensions(), (1, 2,))
|
||||
self.assertEqual(xla_shape.dimensions(), (1, 2))
|
||||
self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
|
||||
|
||||
|
||||
@ -500,18 +501,20 @@ class SingleOpTest(ComputationTest):
|
||||
|
||||
def testConcatenateF32(self):
|
||||
c = self._NewComputation()
|
||||
c.Concatenate(
|
||||
(c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
|
||||
c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))),
|
||||
dimension=0)
|
||||
args = (
|
||||
c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
|
||||
c.Constant(NumpyArrayF32([4.0, 5.0, 6.0])),
|
||||
)
|
||||
c.Concatenate(args, dimension=0)
|
||||
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||
|
||||
def testConcatenateF64(self):
|
||||
c = self._NewComputation()
|
||||
c.Concatenate(
|
||||
(c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
|
||||
c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))),
|
||||
dimension=0)
|
||||
args = (
|
||||
c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
|
||||
c.Constant(NumpyArrayF64([4.0, 5.0, 6.0])),
|
||||
)
|
||||
c.Concatenate(args, dimension=0)
|
||||
self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||
|
||||
def testConvertElementType(self):
|
||||
@ -665,11 +668,13 @@ class SingleOpTest(ComputationTest):
|
||||
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
|
||||
lhs = a(1, 2, 3, 4)
|
||||
rhs = a(1, 2, 1, 2) * 10
|
||||
c.Conv(c.Constant(lhs), c.Constant(rhs),
|
||||
[1, 1], xla_client.PaddingType.SAME)
|
||||
result = np.array([[[[640., 700., 760., 300.],
|
||||
[880., 940., 1000., 380.],
|
||||
[1120., 1180., 1240., 460.]]]])
|
||||
c.Conv(
|
||||
c.Constant(lhs), c.Constant(rhs), [1, 1], xla_client.PaddingType.SAME)
|
||||
result = np.array([[[
|
||||
[640., 700., 760., 300.],
|
||||
[880., 940., 1000., 380.],
|
||||
[1120., 1180., 1240., 460.],
|
||||
]]])
|
||||
self._ExecuteAndCompareClose(c, expected=result)
|
||||
|
||||
def testConvF32Valid(self):
|
||||
@ -677,10 +682,12 @@ class SingleOpTest(ComputationTest):
|
||||
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
|
||||
lhs = a(1, 2, 3, 4)
|
||||
rhs = a(1, 2, 1, 2) * 10
|
||||
c.Conv(c.Constant(lhs), c.Constant(rhs),
|
||||
[2, 1], xla_client.PaddingType.VALID)
|
||||
result = np.array([[[[640., 700., 760.],
|
||||
[1120., 1180., 1240.]]]])
|
||||
c.Conv(
|
||||
c.Constant(lhs), c.Constant(rhs), [2, 1], xla_client.PaddingType.VALID)
|
||||
result = np.array([[[
|
||||
[640., 700., 760.],
|
||||
[1120., 1180., 1240.],
|
||||
]]])
|
||||
self._ExecuteAndCompareClose(c, expected=result)
|
||||
|
||||
def testConvWithGeneralPaddingF32(self):
|
||||
@ -692,12 +699,15 @@ class SingleOpTest(ComputationTest):
|
||||
pads = [(1, 0), (0, 1)]
|
||||
lhs_dilation = (2, 1)
|
||||
rhs_dilation = (1, 1)
|
||||
c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs),
|
||||
strides, pads, lhs_dilation, rhs_dilation)
|
||||
result = np.array([[[[0., 0., 0.],
|
||||
[10., 20., 0.],
|
||||
[0., 0., 0.],
|
||||
[40., 50., 0.]]]])
|
||||
c.ConvWithGeneralPadding(
|
||||
c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
|
||||
rhs_dilation)
|
||||
result = np.array([[[
|
||||
[0., 0., 0.],
|
||||
[10., 20., 0.],
|
||||
[0., 0., 0.],
|
||||
[40., 50., 0.],
|
||||
]]])
|
||||
self._ExecuteAndCompareClose(c, expected=result)
|
||||
|
||||
def testConvGeneralDilatedF32(self):
|
||||
@ -710,13 +720,15 @@ class SingleOpTest(ComputationTest):
|
||||
lhs_dilation = (2, 1)
|
||||
rhs_dilation = (1, 1)
|
||||
dimension_numbers = ("NCHW", "OIHW", "NCHW")
|
||||
c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
|
||||
strides, pads, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers)
|
||||
result = np.array([[[[0., 0., 0.],
|
||||
[10., 20., 0.],
|
||||
[0., 0., 0.],
|
||||
[40., 50., 0.]]]])
|
||||
c.ConvGeneralDilated(
|
||||
c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers)
|
||||
result = np.array([[[
|
||||
[0., 0., 0.],
|
||||
[10., 20., 0.],
|
||||
[0., 0., 0.],
|
||||
[40., 50., 0.],
|
||||
]]])
|
||||
self._ExecuteAndCompareClose(c, expected=result)
|
||||
|
||||
def testConvGeneralDilatedPermutedF32(self):
|
||||
@ -730,13 +742,10 @@ class SingleOpTest(ComputationTest):
|
||||
rhs_dilation = (1, 1)
|
||||
|
||||
dimension_numbers = ("NHWC", "OIHW", "CWNH")
|
||||
c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))),
|
||||
c.Constant(rhs),
|
||||
strides, pads, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers)
|
||||
result = np.array([[[[0., 0., 0.],
|
||||
[10., 20., 0.],
|
||||
[0., 0., 0.],
|
||||
c.ConvGeneralDilated(
|
||||
c.Constant(np.transpose(lhs, (0, 2, 3, 1))), c.Constant(rhs), strides,
|
||||
pads, lhs_dilation, rhs_dilation, dimension_numbers)
|
||||
result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
|
||||
[40., 50., 0.]]]])
|
||||
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
|
||||
|
||||
@ -751,17 +760,20 @@ class SingleOpTest(ComputationTest):
|
||||
rhs_dilation = (1, 1)
|
||||
dimension_numbers = ("NCHW", "OIHW", "NCHW")
|
||||
feature_group_count = 2
|
||||
c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
|
||||
strides, pads, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers, feature_group_count)
|
||||
result = np.array([[[[0., 0., 0.],
|
||||
[10., 20., 0.],
|
||||
[0., 0., 0.],
|
||||
[40., 50., 0.]],
|
||||
[[0., 0., 0.],
|
||||
[330., 380., 160.],
|
||||
[0., 0., 0.],
|
||||
[480., 530., 220.]]]])
|
||||
c.ConvGeneralDilated(
|
||||
c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers, feature_group_count)
|
||||
result = np.array([[[
|
||||
[0., 0., 0.],
|
||||
[10., 20., 0.],
|
||||
[0., 0., 0.],
|
||||
[40., 50., 0.],
|
||||
], [
|
||||
[0., 0., 0.],
|
||||
[330., 380., 160.],
|
||||
[0., 0., 0.],
|
||||
[480., 530., 220.],
|
||||
]]])
|
||||
self._ExecuteAndCompareClose(c, expected=result)
|
||||
|
||||
def testBooleanNot(self):
|
||||
@ -952,14 +964,11 @@ class SingleOpTest(ComputationTest):
|
||||
c = self._NewComputation()
|
||||
c.Pad(
|
||||
c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
|
||||
c.Constant(NumpyArrayF32(0.0)),
|
||||
[(1, 2, 1), (0, 1, 0)])
|
||||
self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0],
|
||||
[1.0, 2.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[3.0, 4.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0]])
|
||||
c.Constant(NumpyArrayF32(0.0)), [(1, 2, 1), (0, 1, 0)])
|
||||
self._ExecuteAndCompareClose(
|
||||
c,
|
||||
expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
|
||||
[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||
|
||||
def testPadWithPaddingConfig(self):
|
||||
c = self._NewComputation()
|
||||
@ -972,14 +981,11 @@ class SingleOpTest(ComputationTest):
|
||||
padding_config.dimensions.append(dimension)
|
||||
c.Pad(
|
||||
c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
|
||||
c.Constant(NumpyArrayF32(0.0)),
|
||||
padding_config)
|
||||
self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0],
|
||||
[1.0, 2.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[3.0, 4.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0]])
|
||||
c.Constant(NumpyArrayF32(0.0)), padding_config)
|
||||
self._ExecuteAndCompareClose(
|
||||
c,
|
||||
expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
|
||||
[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||
|
||||
def testReshape(self):
|
||||
c = self._NewComputation()
|
||||
@ -1102,8 +1108,10 @@ class SingleOpTest(ComputationTest):
|
||||
def testRngNormal(self):
|
||||
shape = (2, 3)
|
||||
c = self._NewComputation()
|
||||
c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)),
|
||||
dims=shape)
|
||||
c.RngNormal(
|
||||
c.Constant(NumpyArrayF32(0.)),
|
||||
c.Constant(NumpyArrayF32(1.)),
|
||||
dims=shape)
|
||||
result = c.Build().Compile().ExecuteWithPythonValues()
|
||||
# since the result is random, we just check shape and uniqueness
|
||||
self.assertEqual(result.shape, shape)
|
||||
@ -1113,8 +1121,10 @@ class SingleOpTest(ComputationTest):
|
||||
lo, hi = 2., 4.
|
||||
shape = (2, 3)
|
||||
c = self._NewComputation()
|
||||
c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)),
|
||||
dims=shape)
|
||||
c.RngUniform(
|
||||
c.Constant(NumpyArrayF32(lo)),
|
||||
c.Constant(NumpyArrayF32(hi)),
|
||||
dims=shape)
|
||||
result = c.Build().Compile().ExecuteWithPythonValues()
|
||||
# since the result is random, we just check shape, uniqueness, and range
|
||||
self.assertEqual(result.shape, shape)
|
||||
@ -1126,8 +1136,10 @@ class SingleOpTest(ComputationTest):
|
||||
lo, hi = 2, 4
|
||||
shape = (2, 3)
|
||||
c = self._NewComputation()
|
||||
c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)),
|
||||
dims=shape)
|
||||
c.RngUniform(
|
||||
c.Constant(NumpyArrayS32(lo)),
|
||||
c.Constant(NumpyArrayS32(hi)),
|
||||
dims=shape)
|
||||
result = c.Build().Compile().ExecuteWithPythonValues()
|
||||
# since the result is random, we just check shape, integrality, and range
|
||||
self.assertEqual(result.shape, shape)
|
||||
@ -1180,13 +1192,21 @@ class SingleOpTest(ComputationTest):
|
||||
dtype=np.float32)
|
||||
|
||||
c = self._NewComputation()
|
||||
c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False,
|
||||
lower=True, transpose_a=True)
|
||||
self._ExecuteAndCompareClose(c, expected=np.array([
|
||||
[0.5, 0.08333334, 0.04629629, 0.03367003],
|
||||
[2.5, -0.25, -0.1388889, -0.1010101],
|
||||
[4.5, -0.58333331, -0.32407406, -0.23569024],
|
||||
], dtype=np.float32), rtol=1e-4)
|
||||
c.TriangularSolve(
|
||||
c.Constant(a_vals),
|
||||
c.Constant(b_vals),
|
||||
left_side=False,
|
||||
lower=True,
|
||||
transpose_a=True)
|
||||
self._ExecuteAndCompareClose(
|
||||
c,
|
||||
expected=np.array([
|
||||
[0.5, 0.08333334, 0.04629629, 0.03367003],
|
||||
[2.5, -0.25, -0.1388889, -0.1010101],
|
||||
[4.5, -0.58333331, -0.32407406, -0.23569024],
|
||||
],
|
||||
dtype=np.float32),
|
||||
rtol=1e-4)
|
||||
|
||||
def testIsConstant(self):
|
||||
c = self._NewComputation()
|
||||
@ -1261,8 +1281,9 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
def _CreateMulF32ByParamComputation(self):
|
||||
"""Computation (f32) -> f32 that multiplies one parameter by the other."""
|
||||
c = self._NewComputation("mul_f32_by_param")
|
||||
c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)))
|
||||
c.Mul(
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)))
|
||||
return c.Build()
|
||||
|
||||
def _CreateMulF64By2Computation(self):
|
||||
@ -1326,15 +1347,17 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
def _CreateBinaryGeF32Computation(self):
|
||||
"""Computation (f32, f32) -> bool that tests first_param >= second_param."""
|
||||
c = self._NewComputation("param0_lt_param1")
|
||||
c.Ge(c.ParameterFromNumpy(NumpyArrayF32(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)))
|
||||
c.Ge(
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF32(0)))
|
||||
return c.Build()
|
||||
|
||||
def _CreateBinaryGeF64Computation(self):
|
||||
"""Computation (f64, f64) -> bool that tests first_param >= second_param."""
|
||||
c = self._NewComputation("param0_lt_param1")
|
||||
c.Ge(c.ParameterFromNumpy(NumpyArrayF64(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF64(0)))
|
||||
c.Ge(
|
||||
c.ParameterFromNumpy(NumpyArrayF64(0)),
|
||||
c.ParameterFromNumpy(NumpyArrayF64(0)))
|
||||
return c.Build()
|
||||
|
||||
def _MakeSample3DArrayF32(self):
|
||||
@ -1415,26 +1438,28 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
|
||||
def testSelectAndScatterF32(self):
|
||||
c = self._NewComputation()
|
||||
c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
|
||||
select=self._CreateBinaryGeF32Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 2),
|
||||
padding=xla_client.PaddingType.VALID,
|
||||
source=c.Constant(NumpyArrayF32([[0.1, 0.2]])),
|
||||
init_value=c.Constant(NumpyArrayF32(1)),
|
||||
scatter=self._CreateBinaryAddF32Computation())
|
||||
c.SelectAndScatter(
|
||||
c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
|
||||
select=self._CreateBinaryGeF32Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 2),
|
||||
padding=xla_client.PaddingType.VALID,
|
||||
source=c.Constant(NumpyArrayF32([[0.1, 0.2]])),
|
||||
init_value=c.Constant(NumpyArrayF32(1)),
|
||||
scatter=self._CreateBinaryAddF32Computation())
|
||||
self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
|
||||
|
||||
def testSelectAndScatterF64(self):
|
||||
c = self._NewComputation()
|
||||
c.SelectAndScatter(c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])),
|
||||
select=self._CreateBinaryGeF64Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 2),
|
||||
padding=xla_client.PaddingType.VALID,
|
||||
source=c.Constant(NumpyArrayF64([[0.1, 0.2]])),
|
||||
init_value=c.Constant(NumpyArrayF64(1)),
|
||||
scatter=self._CreateBinaryAddF64Computation())
|
||||
c.SelectAndScatter(
|
||||
c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])),
|
||||
select=self._CreateBinaryGeF64Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 2),
|
||||
padding=xla_client.PaddingType.VALID,
|
||||
source=c.Constant(NumpyArrayF64([[0.1, 0.2]])),
|
||||
init_value=c.Constant(NumpyArrayF64(1)),
|
||||
scatter=self._CreateBinaryAddF64Computation())
|
||||
self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
|
||||
|
||||
def testReduce1DtoScalarF32(self):
|
||||
@ -1537,61 +1562,73 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
def testReduceWindowValidUnitStridesF32(self):
|
||||
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.ReduceWindow(operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
window_dimensions=(2, 1), window_strides=(1, 1),
|
||||
padding=xla_client.PaddingType.VALID)
|
||||
c.ReduceWindow(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 1),
|
||||
padding=xla_client.PaddingType.VALID)
|
||||
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
|
||||
|
||||
def testReduceWindowSameUnitStridesF32(self):
|
||||
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.ReduceWindow(operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
window_dimensions=(2, 1), window_strides=(1, 1),
|
||||
padding=xla_client.PaddingType.SAME)
|
||||
c.ReduceWindow(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 1),
|
||||
padding=xla_client.PaddingType.SAME)
|
||||
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
|
||||
|
||||
def testReduceWindowValidGeneralStridesF32(self):
|
||||
input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.ReduceWindow(operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
window_dimensions=(2, 1), window_strides=(1, 2),
|
||||
padding=xla_client.PaddingType.VALID)
|
||||
c.ReduceWindow(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF32Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF32Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 2),
|
||||
padding=xla_client.PaddingType.VALID)
|
||||
self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
|
||||
|
||||
def testReduceWindowValidUnitStridesF64(self):
|
||||
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.ReduceWindow(operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
window_dimensions=(2, 1), window_strides=(1, 1),
|
||||
padding=xla_client.PaddingType.VALID)
|
||||
c.ReduceWindow(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 1),
|
||||
padding=xla_client.PaddingType.VALID)
|
||||
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
|
||||
|
||||
def testReduceWindowSameUnitStridesF64(self):
|
||||
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.ReduceWindow(operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
window_dimensions=(2, 1), window_strides=(1, 1),
|
||||
padding=xla_client.PaddingType.SAME)
|
||||
c.ReduceWindow(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 1),
|
||||
padding=xla_client.PaddingType.SAME)
|
||||
self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
|
||||
|
||||
def testReduceWindowValidGeneralStridesF64(self):
|
||||
input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
c = self._NewComputation()
|
||||
c.ReduceWindow(operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
window_dimensions=(2, 1), window_strides=(1, 2),
|
||||
padding=xla_client.PaddingType.VALID)
|
||||
c.ReduceWindow(
|
||||
operand=c.Constant(input_array),
|
||||
init_value=c.ConstantF64Scalar(0),
|
||||
computation_to_apply=self._CreateBinaryAddF64Computation(),
|
||||
window_dimensions=(2, 1),
|
||||
window_strides=(1, 2),
|
||||
padding=xla_client.PaddingType.VALID)
|
||||
self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
|
||||
|
||||
def testWhileF32(self):
|
||||
@ -1636,7 +1673,7 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
to_infeed = NumpyArrayS32([1, 2, 3, 4])
|
||||
c = self._NewComputation()
|
||||
c.Infeed(xla_client.Shape.from_pyval(to_infeed[0]))
|
||||
compiled_c = c.Build().CompileWithExampleArguments()
|
||||
compiled_c = c.Build().Compile()
|
||||
for item in to_infeed:
|
||||
xla_client.transfer_to_infeed(item)
|
||||
|
||||
@ -1650,7 +1687,7 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0]))
|
||||
c.Outfeed(x)
|
||||
|
||||
compiled_c = c.Build().CompileWithExampleArguments()
|
||||
compiled_c = c.Build().Compile()
|
||||
|
||||
for want in to_round_trip:
|
||||
execution = threading.Thread(target=compiled_c.Execute)
|
||||
@ -1673,8 +1710,9 @@ class EmbeddedComputationsTest(ComputationTest):
|
||||
dnums.index_vector_dim = 1
|
||||
|
||||
c = self._NewComputation()
|
||||
c.Scatter(c.Constant(a), c.Constant(scatter_indices), c.Constant(updates),
|
||||
self._CreateBinaryAddS32Computation(), dnums)
|
||||
c.Scatter(
|
||||
c.Constant(a), c.Constant(scatter_indices), c.Constant(updates),
|
||||
self._CreateBinaryAddS32Computation(), dnums)
|
||||
expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32)
|
||||
self._ExecuteAndCompareClose(c, expected=expected)
|
||||
|
||||
@ -1685,15 +1723,34 @@ class ErrorTest(ComputationTest):
|
||||
self.f32_scalar_2 = NumpyArrayF32(2.0)
|
||||
self.s32_scalar_2 = NumpyArrayS32(2)
|
||||
|
||||
def testCompileWithWrongElementTypeInLayout(self):
|
||||
c = self._NewComputation()
|
||||
c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
|
||||
c.ParameterFromNumpy(self.s32_scalar_2)
|
||||
c.ClearOpMetadata()
|
||||
|
||||
options = xla_client.CompileOptions()
|
||||
options.argument_layouts = [xla_client.Shape.array_shape(np.float32, [])]
|
||||
|
||||
def TestFun():
|
||||
return c.Build().Compile(compile_options=options)
|
||||
|
||||
self.assertRaisesRegexp(
|
||||
RuntimeError, r".*Invalid argument shape.*"
|
||||
r"expected s32\[\], got f32\[\].*", TestFun)
|
||||
|
||||
def testInvokeWithWrongElementType(self):
|
||||
c = self._NewComputation()
|
||||
c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
|
||||
c.ParameterFromNumpy(self.s32_scalar_2)
|
||||
c.ClearOpMetadata()
|
||||
|
||||
def TestFun():
|
||||
return c.Build().Compile().ExecuteWithPythonValues([self.f32_scalar_2])
|
||||
|
||||
self.assertRaisesRegexp(
|
||||
RuntimeError, r"Invalid argument shape.*xla_client_test.py.*"
|
||||
r"expected s32\[\], got f32\[\]",
|
||||
lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2]))
|
||||
RuntimeError, r"Invalid argument: Argument does not match.*"
|
||||
r"want s32\[\], got f32\[\].*", TestFun)
|
||||
|
||||
|
||||
class ComputationRootTest(ComputationTest):
|
||||
@ -1706,7 +1763,7 @@ class ComputationRootTest(ComputationTest):
|
||||
extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable
|
||||
|
||||
arg = NumpyArrayF32(1.0)
|
||||
compiled_c = c.Build(result).CompileWithExampleArguments([arg])
|
||||
compiled_c = c.Build(result).Compile()
|
||||
ans = compiled_c.ExecuteWithPythonValues([arg])
|
||||
np.testing.assert_allclose(ans, 4.14)
|
||||
|
||||
|
@ -77,14 +77,13 @@ class XrtBackend(xla_client.Backend):
|
||||
def destructure_tuple(self, c_buffer):
|
||||
return c_buffer.DestructureTuple()
|
||||
|
||||
def compile(self, computation, arg_shapes, result_shape, compile_options):
|
||||
del arg_shapes
|
||||
del result_shape
|
||||
def compile(self, computation, compile_options):
|
||||
# pylint: disable=protected-access
|
||||
program_shape = xla_client._wrap_program_shape(
|
||||
computation.GetProgramShape())
|
||||
# pylint: enable=protected-access
|
||||
proto = computation.GetSerializedProto()
|
||||
# TODO(phawkins): use the layouts in compile_options.
|
||||
arg_shapes = [
|
||||
_make_xla_shape(shape.with_major_to_minor_layout_if_absent())
|
||||
for shape in program_shape.parameter_shapes
|
||||
|
@ -2903,6 +2903,7 @@ cc_library(
|
||||
"hlo_pass_pipeline.h",
|
||||
],
|
||||
deps = [
|
||||
":compilation_stats",
|
||||
":dump",
|
||||
":hlo",
|
||||
":hlo_graph_dumper",
|
||||
@ -2913,7 +2914,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -3914,6 +3914,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "compilation_stats",
|
||||
srcs = ["compilation_stats.cc"],
|
||||
hdrs = ["compilation_stats.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dynamic_index_splitter",
|
||||
srcs = ["dynamic_index_splitter.cc"],
|
||||
|
@ -66,38 +66,16 @@ const absl::optional<std::set<int>>& BackendOptions::allowed_devices() const {
|
||||
return allowed_devices_;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
|
||||
public:
|
||||
explicit EigenThreadPoolWrapper(tensorflow::thread::ThreadPool* pool)
|
||||
: pool_(pool) {}
|
||||
~EigenThreadPoolWrapper() override {}
|
||||
|
||||
void Schedule(std::function<void()> fn) override {
|
||||
pool_->Schedule(std::move(fn));
|
||||
}
|
||||
int NumThreads() const override { return pool_->NumThreads(); }
|
||||
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
|
||||
|
||||
private:
|
||||
tensorflow::thread::ThreadPool* pool_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Define this in .cc file to avoid having to include eigen or forward declare
|
||||
// these types in the header.
|
||||
struct Backend::IntraOpThreadPool {
|
||||
explicit IntraOpThreadPool(const int num_threads)
|
||||
: pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(),
|
||||
"XLAEigen", num_threads)),
|
||||
wrapper(new EigenThreadPoolWrapper(pool.get())),
|
||||
device(new Eigen::ThreadPoolDevice(wrapper.get(),
|
||||
wrapper->NumThreads())) {}
|
||||
device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
|
||||
pool->NumThreads())) {}
|
||||
|
||||
std::unique_ptr<tensorflow::thread::ThreadPool> pool;
|
||||
std::unique_ptr<EigenThreadPoolWrapper> wrapper;
|
||||
std::unique_ptr<Eigen::ThreadPoolDevice> device;
|
||||
};
|
||||
|
||||
|
@ -23,6 +23,24 @@ namespace xla {
|
||||
StatusOr<bool>
|
||||
BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
|
||||
HloInstruction* batch_dot) {
|
||||
// This pass assumes the lhs and rhs batch dimensions are equal and strictly
|
||||
// ascending.
|
||||
const auto& is_iota = [](absl::Span<const int64> dims) {
|
||||
for (int64 i = 0; i < dims.size(); ++i) {
|
||||
if (dims[i] != i) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (!absl::c_equal(
|
||||
batch_dot->dot_dimension_numbers().lhs_batch_dimensions(),
|
||||
batch_dot->dot_dimension_numbers().rhs_batch_dimensions()) ||
|
||||
!is_iota(AsInt64Slice(
|
||||
batch_dot->dot_dimension_numbers().lhs_batch_dimensions()))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers();
|
||||
HloInstruction *lhs = batch_dot->mutable_operand(0),
|
||||
*rhs = batch_dot->mutable_operand(1);
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -30,22 +31,15 @@ namespace xla {
|
||||
class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
explicit BFloat16NormalizationVisitor(
|
||||
HloComputation* computation, const BFloat16Support* bfloat16_support,
|
||||
const BFloat16Support* bfloat16_support,
|
||||
BFloat16Normalization* bfloat16_normalization)
|
||||
: computation_(computation),
|
||||
: computation_(nullptr),
|
||||
bfloat16_support_(bfloat16_support),
|
||||
bfloat16_normalization_(bfloat16_normalization) {}
|
||||
|
||||
bool changed() const { return changed_; }
|
||||
Status DefaultAction(HloInstruction* hlo) override;
|
||||
|
||||
static bool Run(HloComputation* computation,
|
||||
const BFloat16Support* bfloat16_support,
|
||||
BFloat16Normalization* bfloat16_normalization) {
|
||||
BFloat16NormalizationVisitor visitor(computation, bfloat16_support,
|
||||
bfloat16_normalization);
|
||||
TF_CHECK_OK(computation->Accept(&visitor));
|
||||
return visitor.changed_;
|
||||
}
|
||||
Status Preprocess(HloInstruction* hlo) override;
|
||||
|
||||
private:
|
||||
// Checks if the HLO uses BF16 in an unsupported way, and if so, inserts
|
||||
@ -408,18 +402,21 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
return HandleInstruction(hlo);
|
||||
}
|
||||
|
||||
Status BFloat16NormalizationVisitor::Preprocess(HloInstruction* hlo) {
|
||||
computation_ = hlo->parent();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
|
||||
XLA_VLOG_LINES(
|
||||
2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
|
||||
bool changed = false;
|
||||
BFloat16NormalizationVisitor visitor(bfloat16_support_, this);
|
||||
for (auto* comp : module->MakeComputationPostOrder()) {
|
||||
if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_, this)) {
|
||||
changed = true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(comp->Accept(&visitor));
|
||||
}
|
||||
XLA_VLOG_LINES(2,
|
||||
"BFloat16Normalization::Run(), after:\n" + module->ToString());
|
||||
return changed;
|
||||
return visitor.changed();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -105,6 +105,8 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
|
||||
return operand_index == 0;
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
return operand_index == 0 || operand_index == 1;
|
||||
case HloOpcode::kGather:
|
||||
return operand_index == 0;
|
||||
case HloOpcode::kSelect:
|
||||
case HloOpcode::kTupleSelect:
|
||||
return operand_index == 1 || operand_index == 2;
|
||||
|
@ -277,8 +277,8 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
|
||||
// Constructor for CallGraph is private so absl::make_unique can't be used.
|
||||
auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));
|
||||
|
||||
VLOG(2) << "Building call graph for:";
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
VLOG(3) << "Building call graph for:";
|
||||
XLA_VLOG_LINES(3, module->ToString());
|
||||
|
||||
// Construct nodes of the call graph and populate the callsites.
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
@ -309,7 +309,7 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
|
||||
call_graph->SetCallContexts();
|
||||
call_graph->SetNodeDepths();
|
||||
|
||||
XLA_VLOG_LINES(1, call_graph->ToString());
|
||||
XLA_VLOG_LINES(2, call_graph->ToString());
|
||||
|
||||
return call_graph;
|
||||
}
|
||||
|
132
tensorflow/compiler/xla/service/compilation_stats.cc
Normal file
@ -0,0 +1,132 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/compilation_stats.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
class NoopStats : public CompilationStats {
|
||||
public:
|
||||
NoopStats() = default;
|
||||
|
||||
void StartPass(absl::string_view pass_name) override {}
|
||||
|
||||
void EndPass(absl::string_view pass_name) override {}
|
||||
|
||||
void CompilationReport() override {}
|
||||
};
|
||||
|
||||
class Stats : public CompilationStats {
|
||||
public:
|
||||
Stats() = default;
|
||||
|
||||
void StartPass(absl::string_view pass_name) override;
|
||||
|
||||
void EndPass(absl::string_view pass_name) override;
|
||||
|
||||
void CompilationReport() override;
|
||||
|
||||
private:
|
||||
struct PassInfo {
|
||||
PassInfo(absl::string_view name, double duration)
|
||||
: name(name), duration_ms(duration) {}
|
||||
|
||||
absl::string_view name;
|
||||
int num_runs = 1;
|
||||
double duration_ms;
|
||||
};
|
||||
|
||||
// Info about the passes that have been run so far.
|
||||
std::vector<PassInfo> passes_;
|
||||
// Used to avoid nested calls to StartPass.
|
||||
bool pass_running_ = false;
|
||||
absl::string_view current_pass_;
|
||||
// The start time of the currently running pass.
|
||||
uint64 start_micros_;
|
||||
};
|
||||
|
||||
/* static */
|
||||
std::unique_ptr<CompilationStats> CompilationStats::MakeNoopStats() {
|
||||
return absl::make_unique<NoopStats>();
|
||||
}
|
||||
|
||||
/* static */
|
||||
std::unique_ptr<CompilationStats> CompilationStats::MakeStats() {
|
||||
return absl::make_unique<Stats>();
|
||||
}
|
||||
|
||||
void Stats::StartPass(absl::string_view pass_name) {
|
||||
CHECK(!pass_running_) << "Can't start " << pass_name << " while running "
|
||||
<< current_pass_;
|
||||
pass_running_ = true;
|
||||
current_pass_ = pass_name;
|
||||
start_micros_ = tensorflow::Env::Default()->NowMicros();
|
||||
}
|
||||
|
||||
void Stats::EndPass(absl::string_view pass_name) {
|
||||
CHECK(pass_running_);
|
||||
CHECK_EQ(current_pass_, pass_name);
|
||||
pass_running_ = false;
|
||||
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
|
||||
double duration_ms = (end_micros - start_micros_) / 1000.0;
|
||||
passes_.push_back(PassInfo(current_pass_, duration_ms));
|
||||
}
|
||||
|
||||
void Stats::CompilationReport() {
|
||||
CHECK(!pass_running_) << "EndPass never called for " << current_pass_;
|
||||
absl::flat_hash_map<absl::string_view, PassInfo> summary;
|
||||
double total_duration = 0;
|
||||
|
||||
for (auto& pass_run : passes_) {
|
||||
auto pass_name = pass_run.name;
|
||||
total_duration += pass_run.duration_ms;
|
||||
auto it = summary.find(pass_name);
|
||||
if (it == summary.end()) {
|
||||
summary.insert(std::make_pair(pass_name, pass_run));
|
||||
} else {
|
||||
++summary.at(pass_name).num_runs;
|
||||
summary.at(pass_name).duration_ms += pass_run.duration_ms;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<PassInfo> sorted_summary;
|
||||
sorted_summary.reserve(summary.size());
|
||||
for (auto& it : summary) {
|
||||
sorted_summary.push_back(it.second);
|
||||
}
|
||||
absl::c_sort(sorted_summary, [](const PassInfo& a, const PassInfo& b) {
|
||||
// Sort passes that take the longest first, break ties using pass names.
|
||||
return std::make_pair(b.duration_ms, a.name) <
|
||||
std::make_pair(a.duration_ms, b.name);
|
||||
});
|
||||
LOG(INFO) << "Total runtime (ms) of HLO passes: " << total_duration;
|
||||
LOG(INFO) << "Pass name, num runs, time (ms)";
|
||||
for (auto& pass_info : sorted_summary) {
|
||||
LOG(INFO) << pass_info.name << ", " << pass_info.num_runs << ", "
|
||||
<< pass_info.duration_ms;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
48
tensorflow/compiler/xla/service/compilation_stats.h
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_STATS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_STATS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// This class is used to collect information about HLO passes and print some
|
||||
// statistics at the end of compilation. From HloPassPipeline, we call StartPass
|
||||
// before the execution of a pass, and EndPass after. Currently, we only collect
|
||||
// timing information and how many times each pass was run. In the future, we
|
||||
// can add more things, such as the size of the HLO graph after each pass.
|
||||
class CompilationStats {
|
||||
public:
|
||||
virtual ~CompilationStats() = default;
|
||||
|
||||
static std::unique_ptr<CompilationStats> MakeNoopStats();
|
||||
|
||||
static std::unique_ptr<CompilationStats> MakeStats();
|
||||
|
||||
virtual void StartPass(absl::string_view pass_name) = 0;
|
||||
|
||||
virtual void EndPass(absl::string_view pass_name) = 0;
|
||||
|
||||
virtual void CompilationReport() = 0;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_STATS_H_
|
@ -209,7 +209,6 @@ void CompilerFunctor::AddOptimizationPasses(
|
||||
builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
|
||||
}
|
||||
|
||||
builder.DisableUnitAtATime = false;
|
||||
builder.DisableUnrollLoops = opt_level == 0;
|
||||
builder.LoopVectorize = opt_level > 0 && size_level == 0;
|
||||
builder.SLPVectorize = opt_level > 1 && size_level == 0;
|
||||
|
@ -275,6 +275,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
pipeline.AddPass<CallInliner>();
|
||||
pipeline.AddPass<BatchDotSimplification>();
|
||||
pipeline.AddPass<DotDecomposer>();
|
||||
// After canonicalization, there may be more batch dots that can be
|
||||
// simplified.
|
||||
pipeline.AddPass<BatchDotSimplification>();
|
||||
auto cost_model = [](HloInstruction* conv) {
|
||||
// We need a cost model for CPUs. Currently, do nothing.
|
||||
return false;
|
||||
|
@ -28,7 +28,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -101,8 +100,7 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
|
||||
} else {
|
||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
||||
2);
|
||||
tensorflow::EigenThreadPoolWrapper tp(&pool);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
|
||||
ExecutableRunOptions run_options;
|
||||
run_options.set_intra_op_thread_pool(&device);
|
||||
|
||||
|
@ -291,6 +291,7 @@ class DfsHloVisitorBase {
|
||||
// This call is purely a performance hint and can be omitted without
|
||||
// affecting correctness.
|
||||
void ReserveVisitStates(int num) { visit_state_.reserve(num); }
|
||||
size_t VisitStateSize() const { return visit_state_.size(); }
|
||||
|
||||
// Useful when we want to visit the same computation more than once with the
|
||||
// same visitor.
|
||||
|
@ -394,6 +394,7 @@ cc_library(
|
||||
":outfeed_manager",
|
||||
":partition_assignment",
|
||||
":stream_assignment",
|
||||
":stream_executor_util",
|
||||
":thunk",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
@ -418,6 +419,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/kernels:gpu_utils",
|
||||
"//tensorflow/core/platform/default/build_config:cublas_plugin",
|
||||
"//tensorflow/core/platform/default/build_config:cudnn_plugin",
|
||||
"//tensorflow/core/platform/default/build_config:cufft_plugin",
|
||||
@ -467,7 +469,9 @@ cc_library(
|
||||
":gpu_executable",
|
||||
":ir_emission_utils",
|
||||
":redzone_allocator",
|
||||
":stream_executor_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
@ -478,7 +482,9 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:logger",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/kernels:conv_ops",
|
||||
"//tensorflow/core/util/proto:proto_utils",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/time",
|
||||
|
@ -14,7 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
|
||||
|
||||
#include "google/protobuf/any.pb.h"
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/time/time.h"
|
||||
@ -26,7 +28,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/platform/logger.h"
|
||||
@ -76,28 +81,6 @@ string NumBytesToString(int64 bytes) {
|
||||
bytes, "B)");
|
||||
}
|
||||
|
||||
// Acquires a process-global lock on the device pointed to by the given
|
||||
// StreamExecutor.
|
||||
//
|
||||
// This is used to prevent other XLA instances from trying to autotune on this
|
||||
// device while we're using it.
|
||||
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
|
||||
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
||||
// se::Platform*s are global singletons guaranteed to live forever.
|
||||
static auto* mutexes =
|
||||
new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
|
||||
tensorflow::mutex>();
|
||||
|
||||
tensorflow::mutex_lock global_lock(mu);
|
||||
auto it = mutexes
|
||||
->emplace(std::piecewise_construct,
|
||||
std::make_tuple(stream_exec->platform(),
|
||||
stream_exec->device_ordinal()),
|
||||
std::make_tuple())
|
||||
.first;
|
||||
return tensorflow::mutex_lock{it->second};
|
||||
}
|
||||
|
||||
tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
|
||||
tensorflow::CudnnVersion cudnn_version;
|
||||
if (auto* dnn = stream_executor->AsDnn()) {
|
||||
@ -179,33 +162,88 @@ bool CheckRedzones(const RedzoneAllocator& allocator, se::Stream* stream,
|
||||
return false;
|
||||
}
|
||||
|
||||
using ConvCacheKey =
|
||||
std::tuple<se::StreamExecutor*, std::string, std::string, Shape,
|
||||
std::vector<Shape>, std::string, std::string, int64>;
|
||||
|
||||
struct ConvCacheStats {
|
||||
int64 cache_hits = 0;
|
||||
int64 cache_misses = 0;
|
||||
|
||||
void LogStats() {
|
||||
VLOG(1) << "Cache hits: " << cache_hits;
|
||||
VLOG(1) << "Cache misses: " << cache_misses;
|
||||
}
|
||||
};
|
||||
|
||||
StatusOr<ConvCacheKey> AutotuneCacheKeyfromInstruction(
|
||||
const HloCustomCallInstruction* conv, se::StreamExecutor* se) {
|
||||
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
|
||||
conv->backend_config<CudnnConvBackendConfig>());
|
||||
std::vector<Shape> operand_shapes;
|
||||
absl::c_transform(conv->operands(), std::back_inserter(operand_shapes),
|
||||
[&](const HloInstruction* op) { return op->shape(); });
|
||||
|
||||
return std::make_tuple(
|
||||
se, backend_config.SerializeAsString(), conv->custom_call_target(),
|
||||
conv->shape(), std::move(operand_shapes),
|
||||
conv->window().SerializeAsString(),
|
||||
conv->convolution_dimension_numbers().SerializeAsString(),
|
||||
conv->feature_group_count());
|
||||
}
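For illustration, a minimal standalone program (not XLA code; the config bytes and names are made up) shows the idea behind the cache key above: serialize the configuration to a string and let binary equality stand in for semantic equality, which can only cause extra misses, never wrong hits.

#include <iostream>
#include <map>
#include <string>
#include <tuple>

// Hypothetical key: (serialized config, call target, feature group count).
using CacheKey = std::tuple<std::string, std::string, int>;

int main() {
  std::map<CacheKey, int> cache;  // int stands in for AutotuneResult

  CacheKey key{"\x08\x01\x10\x02", "convForward", 1};  // made-up binary config
  auto it = cache.find(key);
  if (it == cache.end()) {
    int best_algorithm = 7;  // placeholder for the real autotuning result
    cache.emplace(key, best_algorithm);
    std::cout << "miss, cached algorithm " << best_algorithm << "\n";
  } else {
    std::cout << "hit, algorithm " << it->second << "\n";
  }
  return 0;
}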
|
||||
|
||||
tensorflow::mutex autotune_cache_lock(tensorflow::LINKER_INITIALIZED);
|
||||
auto& autotune_cache GUARDED_BY(autotune_cache_lock) =
|
||||
*new absl::flat_hash_map<ConvCacheKey, AutotuneResult>();
|
||||
auto& autotune_cache_stats GUARDED_BY(autotune_cache_lock) =
|
||||
*new ConvCacheStats();
|
||||
} // anonymous namespace
|
||||
|
||||
// We could have caching here so that we don't redo this work for two identical
|
||||
// convolutions. Unfortunately our cache key would have to be a tuple
|
||||
// containing the protos passed to this function, and we have no utility for
|
||||
// hashing protos. We could write our own hash functions, but they'd silently
|
||||
// break if we ever added a field to one of the protos. Perhaps we could hack
|
||||
// using the binary-encoded proto as the hash key, on the assumption that two
|
||||
// protos being binary-equal is a sufficient, if not necessary, condition for
|
||||
// proper equality. But that would still leave us open to having unnecessary
|
||||
// cache misses and doing extra work. Overall, caching doesn't seem worth the
|
||||
// trouble, but we may want to revisit this if we ever find a model where
|
||||
// caching would speed up compilation a lot.
|
||||
StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
|
||||
const HloCustomCallInstruction* instr) {
|
||||
XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
|
||||
"CudnnConvAlgorithmPicker::PickBestAlgorithm for ", instr->ToString()));
|
||||
|
||||
const Shape& result_shape = instr->shape().tuple_shapes(0);
|
||||
|
||||
// Don't run this function concurrently on the same GPU.
|
||||
//
|
||||
// This is a bit of a hack and doesn't protect us against arbitrary concurrent
|
||||
// use of a GPU, but it's sufficient to let us compile two HLO modules
|
||||
// concurrently and then run them sequentially.
|
||||
//
|
||||
// Putting the lock in here rather than in PickBestAlgorithmNoCache lets us
|
||||
// avoid ever doing duplicate work. If we have a cache miss, only one thread
// will run PickBestAlgorithmNoCache for a particular device.
|
||||
tensorflow::mutex_lock lock = LockGpu(stream_exec_);
|
||||
|
||||
// We cache the autotuning results to avoid doing the duplicate work,
|
||||
// which can greatly improve both stability (deterministic numeric results
|
||||
// within a process for a given input) and performance (2x speedup on some
|
||||
// models).
|
||||
TF_ASSIGN_OR_RETURN(ConvCacheKey key,
|
||||
AutotuneCacheKeyfromInstruction(instr, stream_exec_));
|
||||
{
|
||||
tensorflow::mutex_lock lock(autotune_cache_lock);
|
||||
auto it = autotune_cache.find(key);
|
||||
if (it != autotune_cache.end()) {
|
||||
autotune_cache_stats.cache_hits++;
|
||||
return it->second;
|
||||
}
|
||||
autotune_cache_stats.cache_misses++;
|
||||
}
|
||||
|
||||
StatusOr<AutotuneResult> result_or = PickBestAlgorithmNoCache(instr);
|
||||
if (result_or.ok()) {
|
||||
tensorflow::mutex_lock lock(autotune_cache_lock);
|
||||
CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
|
||||
}
|
||||
return result_or;
|
||||
}
|
||||
|
||||
StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
const HloCustomCallInstruction* instr) {
|
||||
XLA_SCOPED_LOGGING_TIMER(
|
||||
absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithmImpl for ",
|
||||
instr->ToString()));
|
||||
|
||||
const Shape& result_shape = instr->shape().tuple_shapes(0);
|
||||
|
||||
// Make sure any previous activity on this executor is done. We don't want to
|
||||
// interfere with programs that are still running on the GPU.
|
||||
if (!stream_exec_->SynchronizeAllActivity()) {
|
||||
@ -543,6 +581,12 @@ StatusOr<bool> CudnnConvAlgorithmPicker::Run(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
|
||||
changed |= result;
|
||||
}
|
||||
|
||||
{
|
||||
tensorflow::mutex_lock lock(autotune_cache_lock);
|
||||
autotune_cache_stats.LogStats();
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
|
@ -52,6 +52,8 @@ class CudnnConvAlgorithmPicker : public HloModulePass {
|
||||
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
|
||||
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithm(
|
||||
const HloCustomCallInstruction* instr);
|
||||
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCache(
|
||||
const HloCustomCallInstruction* instr);
|
||||
|
||||
se::StreamExecutor* stream_exec_; // never null
|
||||
DeviceMemoryAllocator* allocator_; // may be null
|
||||
|
@ -364,7 +364,6 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
|
||||
params.output_buf = operand_buffers[1];
|
||||
break;
|
||||
case CudnnConvKind::kForwardActivation: {
|
||||
params.kind = CudnnConvKind::kForwardActivation;
|
||||
params.input_shape = &lhs_shape;
|
||||
params.filter_shape = &rhs_shape;
|
||||
params.output_shape = &conv_result_shape;
|
||||
|
@ -201,8 +201,8 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
|
||||
HloModule m
|
||||
|
||||
f1_computation {
|
||||
f1_p0 = f32[10]{0} parameter(0)
|
||||
ROOT f1_root = f32[10]{0} add(f1_p0, f1_p0)
|
||||
f1_p0 = f32[32]{0} parameter(0)
|
||||
ROOT f1_root = f32[32]{0} add(f1_p0, f1_p0)
|
||||
}
|
||||
|
||||
add_computation {
|
||||
@ -212,16 +212,16 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
|
||||
}
|
||||
|
||||
f2_computation {
|
||||
f2_p0 = f32[10]{0} parameter(0)
|
||||
f2_mul = f32[10]{0} multiply(f2_p0, f2_p0)
|
||||
f2_p0 = f32[32]{0} parameter(0)
|
||||
f2_mul = f32[32]{0} multiply(f2_p0, f2_p0)
|
||||
f2_zero = f32[] constant(0)
|
||||
ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0},
|
||||
to_apply=add_computation
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[10]{0} parameter(0)
|
||||
f1 = f32[10]{0} fusion(p0), kind=kLoop, calls=f1_computation
|
||||
p0 = f32[32]{0} parameter(0)
|
||||
f1 = f32[32]{0} fusion(p0), kind=kLoop, calls=f1_computation
|
||||
ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
|
||||
})")
|
||||
.ValueOrDie();
|
||||
|
@ -17,68 +17,42 @@ limitations under the License.

#include <functional>

#include "absl/strings/str_cat.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
namespace {
|
||||
using MatrixDescriptor = gemm_thunk_internal::MatrixDescriptor;
|
||||
|
||||
// Performs a gemm call without an explicit algorithm on lhs_matrix and
|
||||
// rhs_matrix, and stores the result to output_matrix.
|
||||
template <typename Element>
|
||||
bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
|
||||
MatrixDescriptor output_matrix, double alpha, double beta,
|
||||
se::Stream* stream) {
|
||||
DCHECK(!output_matrix.transpose);
|
||||
// This struct contains the metadata of a matrix, e.g., its base address and
|
||||
// dimensions.
|
||||
struct MatrixDescriptor {
|
||||
se::DeviceMemoryBase data;
|
||||
bool transpose; // Whether this matrix needs to be transposed.
|
||||
int64 num_rows;
|
||||
int64 num_cols;
|
||||
};
|
||||
|
||||
const int64 batch_size = lhs_matrix.batch_size;
|
||||
CHECK_EQ(batch_size, rhs_matrix.batch_size);
|
||||
CHECK_EQ(batch_size, output_matrix.batch_size);
|
||||
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
|
||||
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
|
||||
se::DeviceMemory<Element> output_data(output_matrix.data);
|
||||
using GemmCacheKey =
|
||||
std::tuple<PrimitiveType, bool, int64, int64, bool, int64, int64, double,
|
||||
double, se::blas::ComputationType, se::StreamExecutor*>;
|
||||
|
||||
auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
|
||||
: se::blas::Transpose::kNoTranspose;
|
||||
auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
|
||||
: se::blas::Transpose::kNoTranspose;
|
||||
auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
|
||||
tensorflow::mutex autotune_cache_mu(tensorflow::LINKER_INITIALIZED);
|
||||
auto& autotune_cache GUARDED_BY(autotune_cache_mu) = *new absl::flat_hash_map<
|
||||
GemmCacheKey, absl::optional<se::blas::AlgorithmType>>();
|
||||
int64 cache_hits GUARDED_BY(autotune_cache_mu) = 0;
|
||||
int64 cache_misses GUARDED_BY(autotune_cache_mu) = 0;
|
||||
|
||||
if (batch_size == 1) {
|
||||
return stream
|
||||
->ThenBlasGemm(
|
||||
lhs_transpose, rhs_transpose, output_matrix.num_rows,
|
||||
output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
|
||||
lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
|
||||
/*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
|
||||
&output_data, /*leading dim of output=*/output_matrix.num_rows)
|
||||
.ok();
|
||||
}
|
||||
|
||||
int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
|
||||
int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
|
||||
int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
|
||||
return stream
|
||||
->ThenBlasGemmStridedBatched(
|
||||
lhs_transpose, rhs_transpose, output_matrix.num_rows,
|
||||
output_matrix.num_cols, /*size of reduce dim=*/k,
|
||||
/*alpha=*/alpha, lhs_data,
|
||||
/*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
|
||||
/*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
|
||||
/*beta=*/beta, &output_data,
|
||||
/*leading dim of output=*/output_matrix.num_rows, output_stride,
|
||||
batch_size)
|
||||
.ok();
|
||||
}
|
||||
|
||||
// Like DoGemm, but takes an explicit computation type and algorithm.
|
||||
// Performs a gemm call on lhs_matrix and rhs_matrix, and stores the result
|
||||
// to output_matrix.
|
||||
//
|
||||
// computation_type specifies the type of intermediate values generated during
|
||||
// the matmul (e.g. your input/output matrices could be f16s but you could do
|
||||
// computations with f32s). algorithm is an opaque identifier which functions
|
||||
@ -93,19 +67,16 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
|
||||
// the Stream was valid to begin with); check the is_valid property of the
|
||||
// ProfileResult to see whether the call actually succeeded.
|
||||
template <typename Element>
|
||||
bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
|
||||
bool DoGemmWithAlgorithm(int64 batch_size, MatrixDescriptor lhs_matrix,
|
||||
MatrixDescriptor rhs_matrix,
|
||||
MatrixDescriptor output_matrix, double alpha,
|
||||
double beta,
|
||||
se::blas::ComputationType computation_type,
|
||||
se::blas::AlgorithmType algorithm, se::Stream* stream,
|
||||
se::Stream* stream,
|
||||
absl::optional<se::blas::AlgorithmType> algorithm,
|
||||
se::blas::ProfileResult* output_profile_result) {
|
||||
DCHECK(!output_matrix.transpose);
|
||||
|
||||
CHECK_EQ(1, lhs_matrix.batch_size);
|
||||
CHECK_EQ(1, rhs_matrix.batch_size);
|
||||
CHECK_EQ(1, output_matrix.batch_size);
|
||||
|
||||
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
|
||||
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
|
||||
se::DeviceMemory<Element> output_data(output_matrix.data);
|
||||
@ -116,17 +87,45 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
|
||||
: se::blas::Transpose::kNoTranspose;
|
||||
auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
|
||||
|
||||
return stream
|
||||
->ThenBlasGemmWithAlgorithm(
|
||||
lhs_transpose, rhs_transpose, output_matrix.num_rows,
|
||||
output_matrix.num_cols, /*size of reduce dim=*/k,
|
||||
/*alpha=*/static_cast<Element>(alpha), lhs_data,
|
||||
/*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
|
||||
/*leading dim of RHS=*/rhs_matrix.num_rows,
|
||||
/*beta=*/static_cast<Element>(beta), &output_data,
|
||||
/*leading dim of output=*/output_matrix.num_rows, computation_type,
|
||||
algorithm, output_profile_result)
|
||||
.ok();
|
||||
if (algorithm) {
|
||||
// Autotuning is disabled for batch_size != 1.
|
||||
CHECK_EQ(1, batch_size);
|
||||
return stream
|
||||
->ThenBlasGemmWithAlgorithm(
|
||||
lhs_transpose, rhs_transpose, output_matrix.num_rows,
|
||||
output_matrix.num_cols, /*size of reduce dim=*/k,
|
||||
/*alpha=*/static_cast<Element>(alpha), lhs_data,
|
||||
/*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
|
||||
/*leading dim of RHS=*/rhs_matrix.num_rows,
|
||||
/*beta=*/static_cast<Element>(beta), &output_data,
|
||||
/*leading dim of output=*/output_matrix.num_rows, computation_type,
|
||||
*algorithm, output_profile_result)
|
||||
.ok();
|
||||
} else if (batch_size == 1) {
|
||||
return stream
|
||||
->ThenBlasGemm(
|
||||
lhs_transpose, rhs_transpose, output_matrix.num_rows,
|
||||
output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
|
||||
lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
|
||||
/*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
|
||||
&output_data, /*leading dim of output=*/output_matrix.num_rows)
|
||||
.ok();
|
||||
} else {
|
||||
int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
|
||||
int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
|
||||
int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
|
||||
return stream
|
||||
->ThenBlasGemmStridedBatched(
|
||||
lhs_transpose, rhs_transpose, output_matrix.num_rows,
|
||||
output_matrix.num_cols, /*size of reduce dim=*/k,
|
||||
/*alpha=*/alpha, lhs_data,
|
||||
/*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
|
||||
/*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
|
||||
/*beta=*/beta, &output_data,
|
||||
/*leading dim of output=*/output_matrix.num_rows, output_stride,
|
||||
batch_size)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
// Experimentally tries to pick the best algorithm for the given gemm.
|
||||
@ -136,10 +135,43 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
|
||||
// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
|
||||
// all.
|
||||
template <typename Element>
|
||||
StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
|
||||
MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
|
||||
MatrixDescriptor output_matrix, double alpha, double beta,
|
||||
se::blas::ComputationType computation_type, se::Stream* stream) {
|
||||
absl::optional<se::blas::AlgorithmType> DoUncachedGemmAutotune(
|
||||
PrimitiveType type, se::blas::ComputationType computation_type,
|
||||
int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
|
||||
MatrixDescriptor output_matrix, se::Stream* stream,
|
||||
const Shape output_shape, double alpha, double beta,
|
||||
absl::string_view instr_descr) {
|
||||
if (!stream->BlockHostUntilDone().ok()) {
|
||||
VLOG(2) << "Failed to synchronize GPU for autotuning";
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
VLOG(3) << "Starting autotune of GemmThunk " << instr_descr;
|
||||
|
||||
// If the output buffer already contains a bias then autotune into a
|
||||
// scratch buffer. This avoids overwriting the bias buffer. The scratch
|
||||
// buffer may contain arbitrary garbage values.
|
||||
se::DeviceMemoryBase scratch_data = output_matrix.data;
|
||||
absl::optional<se::ScopedDeviceMemory<char>> allocated_memory;
|
||||
if (beta != 0.0) {
|
||||
se::DeviceMemory<char> out = stream->parent()->AllocateArray<char>(
|
||||
ShapeUtil::ByteSizeOf(output_shape));
|
||||
|
||||
if (out.is_null()) {
|
||||
VLOG(1) << "Allocation failed, using generic algorthm";
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// Destructor ensures deallocation at the end of the scope.
|
||||
allocated_memory.emplace(stream->parent(), out);
|
||||
scratch_data = out;
|
||||
}
|
||||
|
||||
const MatrixDescriptor scratch_descriptor{scratch_data,
|
||||
/*needs_transpose=*/false,
|
||||
output_matrix.num_rows,
|
||||
output_matrix.num_cols};
|
||||
|
||||
std::vector<se::blas::AlgorithmType> algorithms;
|
||||
CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));
|
||||
|
||||
@ -150,9 +182,9 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
|
||||
// for all algorithms if we're targeting < sm_50. But because we pass a
|
||||
// non-null ProfileResult, DoGemmWithAlgorithm should always return true,
|
||||
// and the actual success-ness is returned in ProfileResult::is_valid.
|
||||
CHECK(DoGemmWithAlgorithm<Element>(lhs_matrix, rhs_matrix, output_matrix,
|
||||
alpha, beta, computation_type, algorithm,
|
||||
stream, &profile_result));
|
||||
CHECK(DoGemmWithAlgorithm<Element>(
|
||||
/*batch_size=*/1, lhs_matrix, rhs_matrix, scratch_descriptor, alpha,
|
||||
beta, computation_type, stream, algorithm, &profile_result));
|
||||
|
||||
if (profile_result.is_valid()) {
|
||||
VLOG(3) << "cublas gemm algorithm " << algorithm << " took "
|
||||
@ -167,65 +199,14 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
|
||||
}
|
||||
|
||||
if (best_result.is_valid()) {
|
||||
VLOG(2) << "Autotune on GemmThunk " << instr_descr
|
||||
<< " successful; best algorithm is " << best_result.algorithm();
|
||||
return best_result.algorithm();
|
||||
}
|
||||
|
||||
return InternalError(
|
||||
"Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms "
|
||||
"ran successfully",
|
||||
stream, algorithms.size());
|
||||
}
|
||||
|
||||
// Helper functions to go from a PrimitiveType to a templated version of
|
||||
// DoGemm/DoGemmWithAlgorithm/DoGemmAutotune.
|
||||
auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
|
||||
switch (type) {
|
||||
case F16:
|
||||
return &DoGemm<Eigen::half>;
|
||||
case F32:
|
||||
return &DoGemm<float>;
|
||||
case F64:
|
||||
return &DoGemm<double>;
|
||||
case C64:
|
||||
return &DoGemm<std::complex<float>>;
|
||||
case C128:
|
||||
return &DoGemm<std::complex<double>>;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type.";
|
||||
}
|
||||
}
|
||||
auto GetGemmWithAlgorithmFn(PrimitiveType type)
|
||||
-> decltype(&DoGemmWithAlgorithm<float>) {
|
||||
switch (type) {
|
||||
case F16:
|
||||
return &DoGemmWithAlgorithm<Eigen::half>;
|
||||
case F32:
|
||||
return &DoGemmWithAlgorithm<float>;
|
||||
case F64:
|
||||
return &DoGemmWithAlgorithm<double>;
|
||||
case C64:
|
||||
return &DoGemmWithAlgorithm<std::complex<float>>;
|
||||
case C128:
|
||||
return &DoGemmWithAlgorithm<std::complex<double>>;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type.";
|
||||
}
|
||||
}
|
||||
auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
|
||||
switch (type) {
|
||||
case F16:
|
||||
return &DoGemmAutotune<Eigen::half>;
|
||||
case F32:
|
||||
return &DoGemmAutotune<float>;
|
||||
case F64:
|
||||
return &DoGemmAutotune<double>;
|
||||
case C64:
|
||||
return &DoGemmAutotune<std::complex<float>>;
|
||||
case C128:
|
||||
return &DoGemmAutotune<std::complex<double>>;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type.";
|
||||
}
|
||||
VLOG(1) << "Unable to autotune cuBLAS gemm on stream " << stream
|
||||
<< " none of the " << algorithms.size() << " ran successfully";
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// Converts from an XLA PrimitiveType to a blas::ComputationType, which is used
|
||||
@ -250,6 +231,48 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
absl::optional<se::blas::AlgorithmType> DoGemmAutotune(
|
||||
int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
|
||||
MatrixDescriptor output_matrix, se::Stream* stream,
|
||||
const Shape output_shape, double alpha, double beta,
|
||||
absl::string_view instr_descr) {
|
||||
PrimitiveType type = output_shape.element_type();
|
||||
se::blas::ComputationType computation_type = GetBlasComputationType(type);
|
||||
|
||||
tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent());
|
||||
|
||||
GemmCacheKey key = std::make_tuple(
|
||||
type, lhs_matrix.transpose, lhs_matrix.num_rows, lhs_matrix.num_cols,
|
||||
rhs_matrix.transpose, rhs_matrix.num_rows, rhs_matrix.num_cols, alpha,
|
||||
beta, computation_type, stream->parent());
|
||||
|
||||
tensorflow::mutex_lock cache_lock(autotune_cache_mu);
|
||||
auto it = autotune_cache.find(key);
|
||||
int64 autotuning_requests = cache_hits + cache_misses;
|
||||
if (autotuning_requests && autotuning_requests % 10 == 0) {
|
||||
VLOG(2) << "Autotuning cache hits/(hits + misses): " << cache_hits << "/"
|
||||
<< autotuning_requests;
|
||||
}
|
||||
|
||||
if (it != autotune_cache.end()) {
|
||||
cache_hits++;
|
||||
VLOG(4)
|
||||
<< "Autotuning cache hit, using algorithm (-1 stands for 'generic'): "
|
||||
<< it->second.value_or(-1);
|
||||
return it->second;
|
||||
}
|
||||
cache_misses++;
|
||||
VLOG(4) << "Autotuning cache miss";
|
||||
|
||||
auto result = DoUncachedGemmAutotune<Element>(
|
||||
type, computation_type, batch_size, lhs_matrix, rhs_matrix, output_matrix,
|
||||
stream, output_shape, alpha, beta, instr_descr);
|
||||
|
||||
CHECK(autotune_cache.emplace(key, result).second);
|
||||
return result;
|
||||
}
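A standalone sketch of the locking discipline used by autotune_cache above, with hypothetical names and plain std:: types: one mutex guards both the map and the hit/miss counters, and a cached "no algorithm" result is stored explicitly so failed autotuning is not retried.

#include <cstdint>
#include <iostream>
#include <map>
#include <mutex>
#include <optional>

std::mutex cache_mu;
std::map<int, std::optional<int>> cache;  // key -> best algorithm, if any
int64_t cache_hits = 0;
int64_t cache_misses = 0;

std::optional<int> AutotuneOnce(int key) {
  std::lock_guard<std::mutex> lock(cache_mu);
  auto it = cache.find(key);
  if (it != cache.end()) {
    ++cache_hits;
    return it->second;
  }
  ++cache_misses;
  // Stand-in for the real autotuning; pretend even keys find no algorithm.
  std::optional<int> best;
  if (key % 2 != 0) best = key * 10;
  cache.emplace(key, best);
  return best;
}

int main() {
  AutotuneOnce(3);
  AutotuneOnce(3);  // second call is answered from the cache
  std::cout << "hits=" << cache_hits << " misses=" << cache_misses << "\n";
  return 0;
}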
|
||||
|
||||
DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
|
||||
if (hlo_instruction.opcode() == HloOpcode::kDot) {
|
||||
return hlo_instruction.dot_dimension_numbers();
|
||||
@ -270,6 +293,138 @@ DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
|
||||
return dot->dot_dimension_numbers();
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
Status ExecuteOnStreamParameterized(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler, const BufferAllocation::Slice lhs_buffer,
|
||||
const BufferAllocation::Slice rhs_buffer,
|
||||
const BufferAllocation::Slice output_buffer, const Shape lhs_shape,
|
||||
const Shape rhs_shape, const Shape output_shape,
|
||||
bool implements_whole_instruction, const HloInstruction* hlo_instruction,
|
||||
double alpha, double beta, bool xla_gpu_disable_autotune) {
|
||||
VLOG(2) << "Executing a GemmThunk";
|
||||
|
||||
se::DeviceMemoryBase lhs_data =
|
||||
buffer_allocations.GetDeviceAddress(lhs_buffer);
|
||||
se::DeviceMemoryBase rhs_data =
|
||||
buffer_allocations.GetDeviceAddress(rhs_buffer);
|
||||
se::DeviceMemoryBase output_data =
|
||||
buffer_allocations.GetDeviceAddress(output_buffer);
|
||||
|
||||
DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction);
|
||||
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
|
||||
dim_nums.rhs_batch_dimensions_size());
|
||||
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape.rank());
|
||||
|
||||
int64 row_dim = dim_nums.lhs_batch_dimensions_size();
|
||||
int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
|
||||
int64 batch_size = std::accumulate(output_shape.dimensions().begin(),
|
||||
output_shape.dimensions().end() - 2, 1,
|
||||
std::multiplies<int64>());
|
||||
|
||||
// Check that the batch dims don't cover the last two dims.
|
||||
for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
|
||||
CHECK_NE(row_dim, batch_dim);
|
||||
CHECK_NE(col_dim, batch_dim);
|
||||
}
|
||||
|
||||
// Verify that the non-batch dimensions are minor-most. This is required for
|
||||
// efficient access.
|
||||
for (const auto* shape : {&lhs_shape, &rhs_shape, &output_shape}) {
|
||||
CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
|
||||
CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
|
||||
}
|
||||
|
||||
// BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
|
||||
// matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
|
||||
// their layout. Therefore, we should treat dimension 0 as row and dimension 1
|
||||
// as column when mapping a matrix Dot to BLAS gemm.
|
||||
int64 output_num_rows = output_shape.dimensions(row_dim);
|
||||
int64 output_num_cols = output_shape.dimensions(col_dim);
|
||||
|
||||
// BLAS gemm expects the inputs and the output are in column-major order.
|
||||
// Therefore, we need to convert dot between row-major matrices to that
|
||||
// between column-major matrices. The key insight for the conversion is that,
|
||||
// in linear storage, matrix M in column-major order is identical to the
|
||||
// transpose of M in row-major order. In other words,
|
||||
//
|
||||
// column-major(M) = row-major(M^T).
|
||||
//
|
||||
// Leveraging this insight, we can perform dot between row-major matrices as
|
||||
// follows.
|
||||
//
|
||||
// row-major(C)
|
||||
// = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
|
||||
// = gemm(column-major(B^T), column-major(A^T))
|
||||
// = gemm(row-major(B), row-major(A))
|
||||
//
|
||||
// Although we do not modify the content of A and B in linear memory, we
|
||||
// should use the dimensions of B^T and A^T when calling gemm. For example,
|
||||
// the leading dimension of the LHS matrix of gemm is the number of rows in
|
||||
// B^T and thus the number of columns in B.
|
||||
auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
|
||||
bool transpose) -> MatrixDescriptor {
|
||||
bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
|
||||
bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
|
||||
LayoutUtil::Minor(output_shape.layout(), row_dim);
|
||||
return MatrixDescriptor{
|
||||
data, static_cast<bool>(transpose ^ layout_mismatch),
|
||||
shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
|
||||
shape.dimensions(row_dim + static_cast<int64>(!is_row_major))};
|
||||
};
|
||||
|
||||
MatrixDescriptor lhs_matrix = make_descriptor(
|
||||
lhs_data, lhs_shape, dim_nums.lhs_contracting_dimensions(0) == row_dim);
|
||||
MatrixDescriptor rhs_matrix = make_descriptor(
|
||||
rhs_data, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(
|
||||
implements_whole_instruction ? hlo_instruction : nullptr);
|
||||
|
||||
if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {
|
||||
std::swap(lhs_matrix, rhs_matrix);
|
||||
std::swap(output_num_cols, output_num_rows);
|
||||
}
|
||||
|
||||
const MatrixDescriptor output_matrix{output_data, /*needs_transpose=*/false,
|
||||
output_num_rows, output_num_cols};
|
||||
|
||||
// Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts
|
||||
// to autotune this gemm to figure out the best algorithm.
|
||||
PrimitiveType element_type = output_shape.element_type();
|
||||
se::blas::ComputationType computation_type =
|
||||
GetBlasComputationType(element_type);
|
||||
|
||||
std::string instr_descr =
|
||||
hlo_instruction != nullptr ? hlo_instruction->ToString() : "<null>";
|
||||
|
||||
// Try finding the best algorithm by autotuning, or use older Gemm API
|
||||
// if autotuning is disabled or has failed.
|
||||
absl::optional<se::blas::AlgorithmType> best_algorithm;
|
||||
if (xla_gpu_disable_autotune) {
|
||||
VLOG(2) << "Autotuning disabled, using generic algorithm";
|
||||
} else if (batch_size != 1) {
|
||||
// TODO(b/112111608): Implement auto tune for batched gemm.
|
||||
VLOG(2) << "Batch size is non-singular, using generic algorithm";
|
||||
} else {
|
||||
// Autotune may fail for various reasons (e.g. when CUDA 8 and GPU
|
||||
// sm_50 or older are used). In that case the returned best_algorithm
|
||||
// will be an empty optional.
|
||||
best_algorithm = DoGemmAutotune<Element>(
|
||||
batch_size, lhs_matrix, rhs_matrix, output_matrix, stream, output_shape,
|
||||
alpha, beta, instr_descr);
|
||||
}
|
||||
|
||||
bool launch_ok = DoGemmWithAlgorithm<Element>(
|
||||
batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha, beta,
|
||||
computation_type, stream, best_algorithm,
|
||||
/*output_profile_result=*/nullptr);
|
||||
|
||||
if (!launch_ok) {
|
||||
return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
|
||||
}
|
||||
return Status::OK();
|
||||
}
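The row-major/column-major argument in the comment above can be checked with a tiny standalone program. This is only a sketch (naive loops, no BLAS, made-up names), but it exercises exactly the swap-the-operands trick described there.

#include <cassert>
#include <vector>

// Naive column-major gemm: C(m x n) = A(m x k) * B(k x n); all buffers are
// interpreted in column-major order with leading dimensions m, k, m.
void GemmColMajor(int m, int n, int k, const std::vector<double>& a,
                  const std::vector<double>& b, std::vector<double>* c) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      double acc = 0;
      for (int p = 0; p < k; ++p) acc += a[i + p * m] * b[p + j * k];
      (*c)[i + j * m] = acc;
    }
}

int main() {
  // A is 2x3, B is 3x2, C = A*B is 2x2; all stored row-major.
  std::vector<double> a = {1, 2, 3, 4, 5, 6};     // row-major A
  std::vector<double> b = {7, 8, 9, 10, 11, 12};  // row-major B
  std::vector<double> c(4);

  // The row-major buffers, read column-major, are B^T and A^T. Calling the
  // column-major gemm with the operands swapped computes C^T in column-major
  // order, which is exactly C in row-major order.
  GemmColMajor(/*m=*/2, /*n=*/2, /*k=*/3, /*a=*/b, /*b=*/a, &c);

  // Expected row-major C: [[58, 64], [139, 154]].
  assert(c[0] == 58 && c[1] == 64 && c[2] == 139 && c[3] == 154);
  return 0;
}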
|
||||
|
||||
} // namespace
|
||||
|
||||
GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
|
||||
@ -290,196 +445,30 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
|
||||
beta_(beta),
|
||||
implements_whole_instruction_(implements_whole_instruction) {}
|
||||
|
||||
absl::optional<se::blas::AlgorithmType> GemmThunk::GetGemmAlgorithm(
|
||||
int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
|
||||
MatrixDescriptor output_matrix, se::DeviceMemoryBase output_data,
|
||||
se::Stream* stream) {
|
||||
PrimitiveType element_type = output_shape_.element_type();
|
||||
se::blas::ComputationType computation_type =
|
||||
GetBlasComputationType(element_type);
|
||||
|
||||
// TODO(b/112111608): Implement auto tune for batched gemm.
|
||||
if (batch_size != 1) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
if (GetModuleConfig().debug_options().xla_gpu_disable_autotune()) {
|
||||
VLOG(2) << "Auto-tune disabled, using generic algorithm.";
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
const string& device_name = stream->parent()->GetDeviceDescription().name();
|
||||
auto autotune_it = autotune_results_.find(device_name);
|
||||
if (autotune_it == autotune_results_.end()) {
|
||||
VLOG(3) << "Starting autotune of GemmThunk " << GetThunkName();
|
||||
|
||||
// If the output buffer already contains a bias then autotune into a
|
||||
// scratch buffer. This avoids overwriting the bias buffer. The scratch
|
||||
// buffer may contain arbitrary garbage values.
|
||||
se::DeviceMemoryBase scratch_data = output_data;
|
||||
std::unique_ptr<se::TemporaryDeviceMemory<char>> scratch_mem;
|
||||
if (beta_ != 0.0) {
|
||||
auto temp_status = stream->AllocateTemporaryArray<char>(
|
||||
ShapeUtil::ByteSizeOf(output_shape_));
|
||||
if (!temp_status.ok()) {
|
||||
return false;
|
||||
}
|
||||
scratch_mem = std::move(temp_status).ValueOrDie();
|
||||
scratch_data = scratch_mem->device_memory();
|
||||
}
|
||||
const MatrixDescriptor scratch_descriptor(
|
||||
scratch_data, false, output_matrix.num_rows, output_matrix.num_cols,
|
||||
batch_size);
|
||||
|
||||
StatusOr<se::blas::AlgorithmType> best_algorithm = GetGemmAutotuneFn(
|
||||
element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_, beta_,
|
||||
computation_type, stream);
|
||||
|
||||
autotune_it = autotune_results_.insert({device_name, best_algorithm}).first;
|
||||
|
||||
if (autotune_it->second.ok()) {
|
||||
VLOG(2) << "Autotune on GemmThunk " << GetThunkName()
|
||||
<< " successful; best algorithm is "
|
||||
<< best_algorithm.ValueOrDie();
|
||||
} else {
|
||||
VLOG(2) << "Autotune on GemmThunk " << GetThunkName()
|
||||
<< " unsuccessful. Will use generic gemm.";
|
||||
}
|
||||
}
|
||||
|
||||
if (autotune_it->second.ok()) {
|
||||
return autotune_it->second.ValueOrDie();
|
||||
}
|
||||
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
VLOG(2) << "Executing a GemmThunk";
|
||||
|
||||
se::DeviceMemoryBase lhs_data =
|
||||
buffer_allocations.GetDeviceAddress(lhs_buffer_);
|
||||
se::DeviceMemoryBase rhs_data =
|
||||
buffer_allocations.GetDeviceAddress(rhs_buffer_);
|
||||
se::DeviceMemoryBase output_data =
|
||||
buffer_allocations.GetDeviceAddress(output_buffer_);
|
||||
|
||||
DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
|
||||
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
|
||||
dim_nums.rhs_batch_dimensions_size());
|
||||
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank());
|
||||
|
||||
int64 row_dim = dim_nums.lhs_batch_dimensions_size();
|
||||
int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
|
||||
int64 batch_size = std::accumulate(output_shape_.dimensions().begin(),
|
||||
output_shape_.dimensions().end() - 2, 1,
|
||||
std::multiplies<int64>());
|
||||
|
||||
// Check that the batch dims don't cover the last two dims.
|
||||
for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
|
||||
CHECK_NE(row_dim, batch_dim);
|
||||
CHECK_NE(col_dim, batch_dim);
|
||||
}
|
||||
|
||||
// Verify that the non-batch dimensions are minor-most. This is required for
|
||||
// efficient access.
|
||||
for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) {
|
||||
CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
|
||||
CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
|
||||
}
|
||||
|
||||
// BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
|
||||
// matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
|
||||
// their layout. Therefore, we should treat dimension 0 as row and dimension 1
|
||||
// as column when mapping a matrix Dot to BLAS gemm.
|
||||
int64 output_num_rows = output_shape_.dimensions(row_dim);
|
||||
int64 output_num_cols = output_shape_.dimensions(col_dim);
|
||||
|
||||
// BLAS gemm expects the inputs and the output are in column-major order.
|
||||
// Therefore, we need to convert dot between row-major matrices to that
|
||||
// between column-major matrices. The key insight for the conversion is that,
|
||||
// in linear storage, matrix M in column-major order is identical to the
|
||||
// transpose of M in row-major order. In other words,
|
||||
//
|
||||
// column-major(M) = row-major(M^T).
|
||||
//
|
||||
// Leveraging this insight, we can perform dot between row-major matrices as
|
||||
// follows.
|
||||
//
|
||||
// row-major(C)
|
||||
// = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
|
||||
// = gemm(column-major(B^T), column-major(A^T))
|
||||
// = gemm(row-major(B), row-major(A))
|
||||
//
|
||||
// Although we do not modify the content of A and B in linear memory, we
|
||||
// should use the dimensions of B^T and A^T when calling gemm. For example,
|
||||
// the leading dimension of the LHS matrix of gemm is the number of rows in
|
||||
// B^T and thus the number of columns in B.
|
||||
|
||||
auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
|
||||
bool transpose) -> MatrixDescriptor {
|
||||
bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
|
||||
bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
|
||||
LayoutUtil::Minor(output_shape_.layout(), row_dim);
|
||||
return MatrixDescriptor(
|
||||
data, transpose ^ layout_mismatch,
|
||||
shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
|
||||
shape.dimensions(row_dim + static_cast<int64>(!is_row_major)),
|
||||
batch_size);
|
||||
};
|
||||
|
||||
const MatrixDescriptor lhs_descriptor = make_descriptor(
|
||||
lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim);
|
||||
const MatrixDescriptor rhs_descriptor = make_descriptor(
|
||||
rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim);
|
||||
|
||||
// Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
|
||||
// autotune this gemm to figure out the best algorithm.
|
||||
auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
|
||||
MatrixDescriptor output_matrix, se::Stream* stream) {
|
||||
PrimitiveType element_type = output_shape_.element_type();
|
||||
se::blas::ComputationType computation_type =
|
||||
GetBlasComputationType(element_type);
|
||||
absl::optional<se::blas::AlgorithmType> best_algorithm = GetGemmAlgorithm(
|
||||
batch_size, lhs_matrix, rhs_matrix, output_matrix, output_data, stream);
|
||||
|
||||
if (best_algorithm.has_value()) {
|
||||
auto algorithm = best_algorithm.value();
|
||||
VLOG(2) << "Using algorithm " << algorithm
|
||||
<< " chosen by autotuning on GemmThunk " << GetThunkName();
|
||||
return GetGemmWithAlgorithmFn(element_type)(
|
||||
lhs_matrix, rhs_matrix, output_matrix, alpha_, beta_,
|
||||
computation_type, algorithm, stream,
|
||||
/*output_profile_result=*/nullptr);
|
||||
auto fn = [&]() {
|
||||
switch (output_shape_.element_type()) {
|
||||
case F16:
|
||||
return &ExecuteOnStreamParameterized<Eigen::half>;
|
||||
case F32:
|
||||
return &ExecuteOnStreamParameterized<float>;
|
||||
case F64:
|
||||
return &ExecuteOnStreamParameterized<double>;
|
||||
case C64:
|
||||
return &ExecuteOnStreamParameterized<std::complex<float>>;
|
||||
case C128:
|
||||
return &ExecuteOnStreamParameterized<std::complex<double>>;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type.";
|
||||
}
|
||||
}();
|
||||
|
||||
// Autotune may fail for various reasons (e.g. when CUDA 8 and GPU
// sm_50 or older are used). Use the older Gemm API in that case.
|
||||
return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
|
||||
alpha_, beta_, stream);
|
||||
};
|
||||
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(
|
||||
implements_whole_instruction_ ? hlo_instruction() : nullptr);
|
||||
bool launch_ok;
|
||||
if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
|
||||
launch_ok = launch(lhs_descriptor, rhs_descriptor,
|
||||
MatrixDescriptor(output_data, false, output_num_rows,
|
||||
output_num_cols, batch_size),
|
||||
stream);
|
||||
} else {
|
||||
launch_ok = launch(rhs_descriptor, lhs_descriptor,
|
||||
MatrixDescriptor(output_data, false, output_num_cols,
|
||||
output_num_rows, batch_size),
|
||||
stream);
|
||||
}
|
||||
|
||||
if (!launch_ok) {
|
||||
return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
|
||||
}
|
||||
return Status::OK();
|
||||
return fn(buffer_allocations, stream, profiler, lhs_buffer_, rhs_buffer_,
|
||||
output_buffer_, lhs_shape_, rhs_shape_, output_shape_,
|
||||
implements_whole_instruction_, hlo_instruction(), alpha_, beta_,
|
||||
GetModuleConfig().debug_options().xla_gpu_disable_autotune());
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -29,29 +29,6 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
namespace gemm_thunk_internal {
|
||||
// Internal implementation details for GemmThunk:
|
||||
|
||||
// This struct contains the metadata of a matrix, e.g., its base address and
|
||||
// dimensions.
|
||||
struct MatrixDescriptor {
|
||||
MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose,
|
||||
int64 matrix_num_rows, int64 matrix_num_cols,
|
||||
int64 matrix_batch_size)
|
||||
: data(matrix_data),
|
||||
transpose(needs_transpose),
|
||||
num_rows(matrix_num_rows),
|
||||
num_cols(matrix_num_cols),
|
||||
batch_size(matrix_batch_size) {}
|
||||
|
||||
se::DeviceMemoryBase data;
|
||||
bool transpose; // Whether this matrix needs to be transposed.
|
||||
int64 num_rows;
|
||||
int64 num_cols;
|
||||
int64 batch_size;
|
||||
};
|
||||
} // namespace gemm_thunk_internal
|
||||
|
||||
// This class stores everything that StreamExecutor needs to launch a BLAS gemm.
|
||||
// It is generated by IrEmitter.
|
||||
//
|
||||
@ -76,21 +53,7 @@ class GemmThunk : public Thunk {
|
||||
se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
bool WillAutotuneKernel(se::Stream* stream) override {
|
||||
// We will autotune this kernel if we don't already have an autotune result
|
||||
// for the stream device.
|
||||
return autotune_results_.find(
|
||||
stream->parent()->GetDeviceDescription().name()) ==
|
||||
autotune_results_.end();
|
||||
}
|
||||
|
||||
private:
|
||||
absl::optional<se::blas::AlgorithmType> GetGemmAlgorithm(
|
||||
int64 batch_size, gemm_thunk_internal::MatrixDescriptor lhs_matrix,
|
||||
gemm_thunk_internal::MatrixDescriptor rhs_matrix,
|
||||
gemm_thunk_internal::MatrixDescriptor output_matrix,
|
||||
se::DeviceMemoryBase output_data, se::Stream* stream);
|
||||
|
||||
const BufferAllocation::Slice lhs_buffer_;
|
||||
const BufferAllocation::Slice rhs_buffer_;
|
||||
const BufferAllocation::Slice output_buffer_;
|
||||
@ -103,20 +66,6 @@ class GemmThunk : public Thunk {
|
||||
const double beta_;
|
||||
|
||||
const bool implements_whole_instruction_;
|
||||
|
||||
string GetThunkName() const {
|
||||
return hlo_instruction() != nullptr ? hlo_instruction()->ToString()
|
||||
: "<null>";
|
||||
}
|
||||
|
||||
// Maps device names (StreamExecutor::DeviceDescription::name()) to autotune
|
||||
// results. The map's value is the best algorithm we've found for this thunk
|
||||
// on this device, or an error if none of the algorithms worked and we should
|
||||
// use the regular gemm without an algorithm.
|
||||
//
|
||||
// TODO(b/112415150): Make this thread safe.
|
||||
std::unordered_map<string, StatusOr<se::blas::AlgorithmType>>
|
||||
autotune_results_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -132,13 +132,6 @@ Status GpuExecutable::ExecuteThunks(
|
||||
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
|
||||
}
|
||||
|
||||
// If this thunk is about to autotune then wait for all currently executing
|
||||
// thunks to finish. This reduces noise and thus the probability of
|
||||
// choosing a suboptimal algorithm.
|
||||
if (thunk->WillAutotuneKernel(stream)) {
|
||||
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
|
||||
}
|
||||
|
||||
VLOG(2) << "Executing the thunk for "
|
||||
<< thunk->hlo_instruction()->ToString() << " on stream "
|
||||
<< stream_no;
|
||||
|
@ -153,7 +153,8 @@ bool IsInputFusibleScatter(const HloInstruction& instr) {
|
||||
|
||||
bool IsInputFusible(const HloInstruction& instr) {
|
||||
// Input fusion only handles non-elemental reduction and scatter operations.
|
||||
return IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr);
|
||||
return instr.IsFusible() &&
|
||||
(IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr));
|
||||
}
|
||||
|
||||
bool IsLoopFusible(const HloInstruction& instr) {
|
||||
@ -163,29 +164,42 @@ bool IsLoopFusible(const HloInstruction& instr) {
|
||||
// compute the address of the GTE at the top of the kernel. Often we know the
|
||||
// address of the GTE result statically, so we can do this without chasing any
|
||||
// pointers.
|
||||
return (instr.IsElementwise() && instr.operand_count() > 0) ||
|
||||
instr.opcode() == HloOpcode::kBitcast ||
|
||||
instr.opcode() == HloOpcode::kBroadcast ||
|
||||
instr.opcode() == HloOpcode::kConcatenate ||
|
||||
instr.opcode() == HloOpcode::kDynamicSlice ||
|
||||
instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
|
||||
(instr.opcode() == HloOpcode::kFusion &&
|
||||
instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
|
||||
instr.opcode() == HloOpcode::kGather ||
|
||||
instr.opcode() == HloOpcode::kIota ||
|
||||
instr.opcode() == HloOpcode::kPad ||
|
||||
(instr.opcode() == HloOpcode::kReduce &&
|
||||
!IsReductionFromOrToContiguousDimensions(instr)) ||
|
||||
instr.opcode() == HloOpcode::kReduceWindow ||
|
||||
instr.opcode() == HloOpcode::kReshape ||
|
||||
instr.opcode() == HloOpcode::kReverse ||
|
||||
instr.opcode() == HloOpcode::kSlice ||
|
||||
instr.opcode() == HloOpcode::kTranspose;
|
||||
return instr.IsFusible() &&
|
||||
((instr.IsElementwise() && instr.operand_count() > 0) ||
|
||||
instr.opcode() == HloOpcode::kBitcast ||
|
||||
instr.opcode() == HloOpcode::kBroadcast ||
|
||||
instr.opcode() == HloOpcode::kConcatenate ||
|
||||
instr.opcode() == HloOpcode::kDynamicSlice ||
|
||||
instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
|
||||
(instr.opcode() == HloOpcode::kFusion &&
|
||||
instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
|
||||
instr.opcode() == HloOpcode::kGather ||
|
||||
instr.opcode() == HloOpcode::kIota ||
|
||||
instr.opcode() == HloOpcode::kPad ||
|
||||
(instr.opcode() == HloOpcode::kReduce &&
|
||||
!IsReductionFromOrToContiguousDimensions(instr)) ||
|
||||
instr.opcode() == HloOpcode::kReduceWindow ||
|
||||
instr.opcode() == HloOpcode::kReshape ||
|
||||
instr.opcode() == HloOpcode::kReverse ||
|
||||
instr.opcode() == HloOpcode::kSlice ||
|
||||
instr.opcode() == HloOpcode::kTranspose);
|
||||
}
|
||||
|
||||
bool IsFusible(const HloInstruction& instr) {
|
||||
return IsInputFusible(instr) || IsLoopFusible(instr);
|
||||
}
|
||||
|
||||
bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
|
||||
// We can fuse reduces and loop fusions. Elementwise instructions can be fused
|
||||
// with any other instruction.
|
||||
// Note that scatter cannot be the root of a multi-output fusion because
|
||||
// its emitter doesn't support it.
|
||||
|
||||
return instr.IsFusible() &&
|
||||
(IsInputFusibleReduction(instr) ||
|
||||
instr.IsLoopFusion() || // TODO(b/130013493): Use IsLoopFusible here.
|
||||
instr.IsElementwise());
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -69,6 +69,10 @@ bool IsInputFusibleScatter(const HloInstruction& instr);
|
||||
bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
|
||||
const HloInstruction& instr2);
|
||||
|
||||
// Whether `instr` is a candidate for sibling fusion or as a consumer in
|
||||
// a producer-consumer multi-output fusion.
|
||||
bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
|
@ -469,11 +469,12 @@ TEST_F(GpuFusibleTest,
|
||||
TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) {
|
||||
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
|
||||
ENTRY reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
c0 = f32[] constant(0)
|
||||
exp = f32[2,2,2]{2,1,0} exponential(p0)
|
||||
reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add
|
||||
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp)
|
||||
exp = f32[32,32,32]{2,1,0} exponential(p0)
|
||||
reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
|
||||
to_apply=scalar_add
|
||||
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
const HloInstruction* reduce =
|
||||
@ -573,24 +574,28 @@ TEST_F(GpuFusibleTest,
|
||||
ShapesCompatibleForMultiOutputFusion_DifferentReduceDimensions) {
|
||||
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
|
||||
fused_reduce_1 {
|
||||
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
c0 = f32[] constant(0)
|
||||
ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} p0.1, f32[] c0), dimensions={0}, to_apply=scalar_add
|
||||
ROOT reduce = f32[32,32]{1,0} reduce(f32[32,32,32]{2,1,0} p0.1, f32[] c0),
|
||||
dimensions={0}, to_apply=scalar_add
|
||||
}
|
||||
|
||||
fused_reduce_2 {
|
||||
p0.2 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2)
|
||||
p0.2 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
mul = f32[32,32,32]{2,1,0} multiply(f32[32,32,32]{2,1,0} p0.2,
|
||||
f32[32,32,32]{2,1,0} p0.2)
|
||||
c1 = f32[] constant(0)
|
||||
ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={2}, to_apply=scalar_add
|
||||
ROOT reduce = f32[32,32]{1,0} reduce(f32[32,32,32]{2,1,0} mul, f32[] c1),
|
||||
dimensions={2}, to_apply=scalar_add
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||
reduce_1 = f32[2,2]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1
|
||||
reduce_2 = f32[2,2]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2
|
||||
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce_1, reduce_2)
|
||||
p0 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
p1 = f32[32,32,32]{2,1,0} parameter(1)
|
||||
reduce_1 = f32[32,32]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1
|
||||
reduce_2 = f32[32,32]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2
|
||||
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
|
||||
tuple(reduce_1, reduce_2)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
const HloInstruction* fusion_1 =
|
||||
@ -604,28 +609,31 @@ TEST_F(GpuFusibleTest,
|
||||
ShapesCompatibleForMultiOutputFusion_NoReductionToVector) {
|
||||
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
|
||||
fused_element_wise {
|
||||
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||
ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
|
||||
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
p1.1 = f32[32,32,32]{2,1,0} parameter(1)
|
||||
ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
|
||||
}
|
||||
|
||||
fused_reduce {
|
||||
p0.2 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2,
|
||||
f32[2,2,2]{2,1,0} p0.2)
|
||||
broadcast = f32[2,2,2,2]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
|
||||
p0.2 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
mul = f32[32,32,32]{2,1,0} multiply(f32[32,32,32]{2,1,0} p0.2,
|
||||
f32[32,32,32]{2,1,0} p0.2)
|
||||
broadcast = f32[32,32,32,32]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
|
||||
c1 = f32[] constant(0)
|
||||
// Note that reduce is not a reduction-to-vector.
|
||||
ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2,2]{3,2,1,0} broadcast,
|
||||
ROOT reduce = f32[32,32]{1,0} reduce(f32[32,32,32,32]{3,2,1,0} broadcast,
|
||||
f32[] c1), dimensions={1,3}, to_apply=scalar_add
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||
element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
|
||||
fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce
|
||||
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
|
||||
p0 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
p1 = f32[32,32,32]{2,1,0} parameter(1)
|
||||
element_wise = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop,
|
||||
calls=fused_element_wise
|
||||
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(element_wise),
|
||||
kind=kLoop, calls=fused_reduce
|
||||
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
|
||||
tuple(fusion, element_wise)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
const HloInstruction* fusion_1 =
|
||||
|
@ -82,6 +82,45 @@ bool DotImplementedAsGemm(const HloInstruction& dot) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Given a shape and a group of contiguous dimensions in the shape, returns
|
||||
// a tuple of three values (major, middle, minor), where major is the size of
|
||||
// the dimensions more major than the given dimensions, minor is the size of
// dimensions more minor than the given dimensions, and middle is the size of
|
||||
// the given dimensions.
|
||||
std::tuple<int64, int64, int64> PartitionShapeByMiddleDimensions(
|
||||
const Shape& shape, DimensionVector dims_middle) {
|
||||
CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
|
||||
|
||||
absl::Span<const int64> minor_to_major = LayoutUtil::MinorToMajor(shape);
|
||||
int64 values[3] = {1, 1, 1};
|
||||
enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
|
||||
Segment cur_segment = kMinor;
|
||||
|
||||
// Iterate through the dimensions for the three segments in the order of
|
||||
// minor, middle and major to accumulate the size of each segment.
|
||||
absl::c_for_each(minor_to_major, [&](int64 cur_dim) {
|
||||
if (cur_segment != kMajor) {
|
||||
// Handle change of segments.
|
||||
bool cur_dim_in_middle = absl::c_any_of(
|
||||
dims_middle, [&](int64 dim) { return dim == cur_dim; });
|
||||
if (cur_segment == kMinor) {
|
||||
if (cur_dim_in_middle) {
|
||||
cur_segment = kMiddle;
|
||||
}
|
||||
} else if (cur_segment == kMiddle) {
|
||||
if (!cur_dim_in_middle) {
|
||||
cur_segment = kMajor;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
values[cur_segment] *= shape.dimensions(cur_dim);
|
||||
});
|
||||
|
||||
return std::make_tuple(values[kMajor], values[kMiddle], values[kMinor]);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool ImplementedAsGemm(const HloInstruction& hlo) {
|
||||
@ -174,10 +213,74 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
|
||||
dims_to_keep.push_back(dim);
|
||||
}
|
||||
}
|
||||
return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
|
||||
dims_to_keep) ||
|
||||
LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
|
||||
if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
|
||||
dims_to_keep) &&
|
||||
!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
|
||||
reduce.dimensions())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_row_reduction;
|
||||
DimensionVector dims_in_elem;
|
||||
std::tie(is_row_reduction, dims_in_elem) =
|
||||
GetReductionKindAndContiguousComponents(input->shape(),
|
||||
reduce.dimensions());
|
||||
|
||||
if (is_row_reduction) {
|
||||
// For row reduction, the tile block is 1 x tile_size_x, and we are reducing
|
||||
// along tile_size_x which needs to be large enough to make the tiling
|
||||
// implementation efficient.
|
||||
return dims_in_elem[2] >= kWarpSize;
|
||||
}
|
||||
|
||||
// For column reduction, the tile block is tile_size_y x tile_size_x, and we
|
||||
// are reducing along tile_size_y. Both tile_size_x and tile_size_y need to be
|
||||
// large enough to make the tiling implementation efficient.
|
||||
return dims_in_elem[2] >= kWarpSize && dims_in_elem[1] >= kWarpSize;
|
||||
}
|
||||
|
||||
std::pair<bool, DimensionVector> GetReductionKindAndContiguousComponents(
|
||||
const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
|
||||
DimensionVector dims_to_keep;
|
||||
for (int64 dim = 0; dim < input_shape.rank(); ++dim) {
|
||||
if (!absl::c_linear_search(dims_to_reduce, dim)) {
|
||||
dims_to_keep.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
if (dims_to_keep.empty()) {
|
||||
return std::make_pair(
|
||||
true, DimensionVector{1, 1, ShapeUtil::ElementsIn(input_shape)});
|
||||
}
|
||||
|
||||
if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
|
||||
dims_to_keep)) {
|
||||
int64 num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1;
|
||||
std::tie(num_reduced_major, num_kept, num_reduced_minor) =
|
||||
PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
|
||||
if (num_kept == 1) {
|
||||
return std::make_pair(
|
||||
true, DimensionVector{1, 1, num_reduced_minor * num_reduced_major});
|
||||
}
|
||||
if (num_reduced_minor == 1) {
|
||||
return std::make_pair(false,
|
||||
DimensionVector{1, num_reduced_major, num_kept});
|
||||
}
|
||||
return std::make_pair(
|
||||
true, DimensionVector{num_reduced_major, num_kept, num_reduced_minor});
|
||||
}
|
||||
|
||||
int64 num_kept_major = 1, num_reduced = 1, num_kept_minor = 1;
|
||||
std::tie(num_kept_major, num_reduced, num_kept_minor) =
|
||||
PartitionShapeByMiddleDimensions(
|
||||
input_shape,
|
||||
DimensionVector(dims_to_reduce.begin(), dims_to_reduce.end()));
|
||||
if (num_kept_minor == 1) {
|
||||
return std::make_pair(true,
|
||||
DimensionVector{1, num_kept_major, num_reduced});
|
||||
}
|
||||
return std::make_pair(
|
||||
false, DimensionVector{num_kept_major, num_reduced, num_kept_minor});
|
||||
}
|
||||
|
||||
// This emits a device-side call to
|
||||
|
@ -152,6 +152,26 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo);
|
||||
// kept are contiguous in the input of the reduce instruction.
|
||||
bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce);
|
||||
|
||||
// Given the input shape and dimensions to reduce for a reduction, returns
|
||||
// <is_row_reduction, DimensionVector>:
|
||||
// is_row_reduction: indicates whether the reduction is a row reduction or a
|
||||
// column reduction.
|
||||
// DimensionVector: contains the size of the three contiguous components for the
|
||||
// reduction [depth, height, width]. For row reduction, height is the size of
|
||||
// the dimensions to keep, depth is the size of the dimensions to reduce that
|
||||
// are more major than the dimensions to keep, and width is the size of the
|
||||
// dimensions to reduce that are more minor than the dimensions to keep. For
|
||||
// column reduction, height is the size of dimensions to reduce, depth is the
// size of the dimensions to keep that are more major than the dimensions
|
||||
// to reduce, and width is the size of the dimensions to keep that are more
|
||||
// minor than the dimensions to reduce.
|
||||
//
|
||||
// Prerequisite: the reduction instruction passes the check
|
||||
// IsReductionFromOrToContiguousDimensions, which guarantees either the
|
||||
// dimensions to reduce or the dimensions to keep are consecutive.
|
||||
std::pair<bool, DimensionVector> GetReductionKindAndContiguousComponents(
|
||||
const Shape& input_shape, absl::Span<const int64> dims_to_reduce);
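// Minimal usage sketch (the shape and reduce dimensions below are
// hypothetical, not taken from this change):
//
//   bool is_row_reduction;
//   DimensionVector dims;  // [depth, height, width]
//   std::tie(is_row_reduction, dims) = GetReductionKindAndContiguousComponents(
//       ShapeUtil::MakeShapeWithLayout(F32, {8, 64, 32}, {2, 1, 0}),
//       /*dims_to_reduce=*/{1});
//   // is_row_reduction == false, dims == {8, 64, 32}.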
|
||||
|
||||
// Emits call to "vprintf" with given format and arguments.
|
||||
llvm::Value* EmitPrintf(absl::string_view fmt,
|
||||
absl::Span<llvm::Value* const> arguments,
|
||||
|
@ -3546,105 +3546,6 @@ Status AreFusedReductionOutputsConsistent(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Given a shape and a group of contiguous dimensions in the shape, returns
|
||||
// a tuple of three values (major, middle, minor), where major is the size of
|
||||
// the dimensions more major then the given dimensions, minor is the size of
|
||||
// dimensions more minor then the given dimensions, and middle is the size of
|
||||
// the given dimensions.
|
||||
std::tuple<int64, int64, int64> PartitionShapeByMiddleDimensions(
|
||||
const Shape& shape, DimensionVector dims_middle) {
|
||||
CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
|
||||
|
||||
absl::Span<const int64> minor_to_major = LayoutUtil::MinorToMajor(shape);
|
||||
int64 values[3] = {1, 1, 1};
|
||||
enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
|
||||
Segment cur_segment = kMinor;
|
||||
|
||||
// Iterate through the dimensions for the three segments in the order of
|
||||
// minor, middle and major to accumulate the size of each segment.
|
||||
absl::c_for_each(minor_to_major, [&](int64 cur_dim) {
|
||||
if (cur_segment != kMajor) {
|
||||
// Handle change of segments.
|
||||
bool cur_dim_in_middle = absl::c_any_of(
|
||||
dims_middle, [&](int64 dim) { return dim == cur_dim; });
|
||||
if (cur_segment == kMinor) {
|
||||
if (cur_dim_in_middle) {
|
||||
cur_segment = kMiddle;
|
||||
}
|
||||
} else if (cur_segment == kMiddle) {
|
||||
if (!cur_dim_in_middle) {
|
||||
cur_segment = kMajor;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
values[cur_segment] *= shape.dimensions(cur_dim);
|
||||
});
|
||||
|
||||
return std::make_tuple(values[kMajor], values[kMiddle], values[kMinor]);
|
||||
}
|
||||
|
||||
// Given the input shape and dimensions to reduce for a reduction, returns
|
||||
// <is_row_reduction, DimensionVector>:
|
||||
// is_row_reduction: indicates whether the reduction is a row reduction or a
|
||||
// column reduction.
|
||||
// DimensionVector: contains the size of the three contiguous components for the
|
||||
// reduction [depth, height, width]. For row reduction, height is the size of
|
||||
// the dimensions to keep, depth is the size of the dimensions to reduce that
|
||||
// are more major than the dimensions to keep, and width is the size of the
|
||||
// dimensions to reduce that are more minor than the dimensions to keep. For
|
||||
// column reduction, height is the size of dimensions to reduce, depth is the
|
||||
// the size of the dimensions to keep that are more major than the dimensions
|
||||
// to reduce, and width is the size of the dimensions to keep that are more
|
||||
// minor than the dimensions to reduce.
|
||||
//
|
||||
// Prerequisite: the reduction instruction passes the check
|
||||
// IsReductionFromOrToContiguousDimensions, which guarantees either the
|
||||
// dimensions to reduce or the dimensions to keep are consecutive.
|
||||
std::pair<bool, DimensionVector> GetReductionKindAndContiguousComponents(
|
||||
const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
|
||||
DimensionVector dims_to_keep;
|
||||
for (int64 dim = 0; dim < input_shape.rank(); ++dim) {
|
||||
if (!absl::c_linear_search(dims_to_reduce, dim)) {
|
||||
dims_to_keep.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
if (dims_to_keep.empty()) {
|
||||
return std::make_pair(
|
||||
true, DimensionVector{1, 1, ShapeUtil::ElementsIn(input_shape)});
|
||||
}
|
||||
|
||||
if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
|
||||
dims_to_keep)) {
|
||||
int64 num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1;
|
||||
std::tie(num_reduced_major, num_kept, num_reduced_minor) =
|
||||
PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
|
||||
if (num_kept == 1) {
|
||||
return std::make_pair(
|
||||
true, DimensionVector{1, 1, num_reduced_minor * num_reduced_major});
|
||||
}
|
||||
if (num_reduced_minor == 1) {
|
||||
return std::make_pair(false,
|
||||
DimensionVector{1, num_reduced_major, num_kept});
|
||||
}
|
||||
return std::make_pair(
|
||||
true, DimensionVector{num_reduced_major, num_kept, num_reduced_minor});
|
||||
}
|
||||
|
||||
int64 num_kept_major = 1, num_reduced = 1, num_kept_minor = 1;
|
||||
std::tie(num_kept_major, num_reduced, num_kept_minor) =
|
||||
PartitionShapeByMiddleDimensions(
|
||||
input_shape,
|
||||
DimensionVector(dims_to_reduce.begin(), dims_to_reduce.end()));
|
||||
if (num_kept_minor == 1) {
|
||||
return std::make_pair(true,
|
||||
DimensionVector{1, num_kept_major, num_reduced});
|
||||
}
|
||||
return std::make_pair(
|
||||
false, DimensionVector{num_kept_major, num_reduced, num_kept_minor});
|
||||
}
|
||||
|
||||
// Returns true if all the transitive users of hlo before hitting users in
|
||||
// use_chain_endings are elementwise operations.
|
||||
bool AreUsersElementwise(const HloInstruction* hlo,
|
||||
|
@ -219,7 +219,6 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
|
||||
builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
|
||||
}
|
||||
|
||||
builder.DisableUnitAtATime = false;
|
||||
builder.DisableUnrollLoops = opt_level == 0;
|
||||
builder.LoopVectorize = opt_level > 0;
|
||||
builder.SLPVectorize = opt_level > 1 && size_level < 2;
|
||||
|
@ -45,11 +45,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
|
||||
}
|
||||
|
||||
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
|
||||
// We can fuse reduces and loop fusions. Elementwise instructions can be fused
|
||||
// with any other instruction.
|
||||
return instr->IsFusible() &&
|
||||
(IsInputFusibleReduction(*instr) || instr->IsLoopFusion() ||
|
||||
instr->IsElementwise());
|
||||
return IsFusibleAsMultiOutputFusionRoot(*instr);
|
||||
}
|
||||
|
||||
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
|
||||
|
@ -400,11 +400,12 @@ TEST_F(MultiOutputFusionTest,
|
||||
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
|
||||
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
|
||||
ENTRY reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
c0 = f32[] constant(0)
|
||||
exp = f32[2,2,2]{2,1,0} exponential(p0)
|
||||
reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation
|
||||
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp)
|
||||
exp = f32[32,32,32]{2,1,0} exponential(p0)
|
||||
reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
|
||||
to_apply=scalar_add_computation
|
||||
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
|
||||
@ -420,18 +421,19 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
|
||||
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
|
||||
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
|
||||
fused_add {
|
||||
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||
ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
|
||||
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
p1.1 = f32[32,32,32]{2,1,0} parameter(1)
|
||||
ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||
p0 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
p1 = f32[32,32,32]{2,1,0} parameter(1)
|
||||
c0 = f32[] constant(0)
|
||||
add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
|
||||
reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation
|
||||
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add)
|
||||
add = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
|
||||
reduce = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
|
||||
to_apply=scalar_add_computation
|
||||
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, add)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
|
||||
@ -447,31 +449,37 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
|
||||
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
|
||||
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
|
||||
fused_select {
|
||||
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||
p1.1 = f32[32,32,32]{2,1,0} parameter(1)
|
||||
c0 = f32[] constant(0)
|
||||
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
|
||||
greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT
|
||||
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
|
||||
broadcast = f32[32,32,32]{2,1,0} broadcast(f32[] c0), dimensions={}
|
||||
greater-than = pred[32,32,32]{2,1,0} compare(f32[32,32,32]{2,1,0} p1.1,
|
||||
f32[32,32,32]{2,1,0} broadcast), direction=GT
|
||||
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
ROOT select = f32[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
|
||||
greater-than, f32[32,32,32]{2,1,0} p0.1, f32[32,32,32]{2,1,0} broadcast)
|
||||
}
|
||||
|
||||
fused_reduce {
|
||||
p0.2 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0.2 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
c1 = f32[] constant(0)
|
||||
r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation
|
||||
mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2)
|
||||
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
|
||||
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
|
||||
r1 = f32[32,32]{1,0} reduce(p0.2, c1), dimensions={2},
|
||||
to_apply=scalar_add_computation
|
||||
mul = f32[32,32,32]{2,1,0} multiply(p0.2, p0.2)
|
||||
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
|
||||
to_apply=scalar_add_computation
|
||||
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||
select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
|
||||
fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
|
||||
gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
|
||||
gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
|
||||
ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select)
|
||||
p0 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
p1 = f32[32,32,32]{2,1,0} parameter(1)
|
||||
select = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
|
||||
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
|
||||
calls=fused_reduce
|
||||
gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
|
||||
gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
|
||||
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
|
||||
tuple(gte1, gte1, select)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
|
||||
@ -518,30 +526,36 @@ TEST_F(MultiOutputFusionTest,
|
||||
ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
|
||||
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
|
||||
fused_select {
|
||||
p1.1 = f16[2,2,2]{2,1,0} parameter(1)
|
||||
p1.1 = f16[32,32,32]{2,1,0} parameter(1)
|
||||
c0 = f16[] constant(0)
|
||||
broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={}
|
||||
greater-than = pred[2,2,2]{2,1,0} compare(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast), direction=GT
|
||||
p0.1 = f16[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast)
|
||||
broadcast = f16[32,32,32]{2,1,0} broadcast(f16[] c0), dimensions={}
|
||||
greater-than = pred[32,32,32]{2,1,0} compare(f16[32,32,32]{2,1,0} p1.1,
|
||||
f16[32,32,32]{2,1,0} broadcast), direction=GT
|
||||
p0.1 = f16[32,32,32]{2,1,0} parameter(0)
|
||||
ROOT select = f16[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
|
||||
greater-than, f16[32,32,32]{2,1,0} p0.1, f16[32,32,32]{2,1,0} broadcast)
|
||||
}
|
||||
fused_reduce {
|
||||
p0.2 = f16[2,2,2]{2,1,0} parameter(0)
|
||||
convert = f32[2,2,2]{2,1,0} convert(p0.2)
|
||||
p0.2 = f16[32,32,32]{2,1,0} parameter(0)
|
||||
convert = f32[32,32,32]{2,1,0} convert(p0.2)
|
||||
c1 = f32[] constant(0)
|
||||
r1 = f32[2,2]{1,0} reduce(convert, c1), dimensions={2}, to_apply=scalar_add_computation
|
||||
mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
|
||||
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
|
||||
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
|
||||
r1 = f32[32,32]{1,0} reduce(convert, c1), dimensions={2},
|
||||
to_apply=scalar_add_computation
|
||||
mul = f32[32,32,32]{2,1,0} multiply(convert, convert)
|
||||
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
|
||||
to_apply=scalar_add_computation
|
||||
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
|
||||
}
|
||||
ENTRY reduce {
|
||||
p0 = f16[2,2,2]{2,1,0} parameter(0)
|
||||
p1 = f16[2,2,2]{2,1,0} parameter(1)
|
||||
select = f16[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
|
||||
fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
|
||||
gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
|
||||
gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
|
||||
ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) tuple(gte1, gte1, select)
|
||||
p0 = f16[32,32,32]{2,1,0} parameter(0)
|
||||
p1 = f16[32,32,32]{2,1,0} parameter(1)
|
||||
select = f16[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
|
||||
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
|
||||
calls=fused_reduce
|
||||
gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
|
||||
gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
|
||||
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f16[32,32,32]{2,1,0})
|
||||
tuple(gte1, gte1, select)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
|
||||
|
@ -162,5 +162,23 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
|
||||
|
||||
return std::make_tuple(input_layout, filter_layout, output_layout);
|
||||
}
|
||||
|
||||
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
|
||||
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
||||
// se::Platform*s are global singletons guaranteed to live forever.
|
||||
static auto* mutexes =
|
||||
new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
|
||||
tensorflow::mutex>();
|
||||
|
||||
tensorflow::mutex_lock global_lock(mu);
|
||||
auto it = mutexes
|
||||
->emplace(std::piecewise_construct,
|
||||
std::make_tuple(stream_exec->platform(),
|
||||
stream_exec->device_ordinal()),
|
||||
std::make_tuple())
|
||||
.first;
|
||||
return tensorflow::mutex_lock{it->second};
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -45,6 +45,14 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
|
||||
const Layout& input, const Layout& filter,
|
||||
const Layout& output);
|
||||
|
||||
// Generates and returns a unique lock for each provided executor.
// Guarantees that two blocks of code that each hold a lock for the same
// executor (as given by this function) will not run concurrently.
|
||||
//
|
||||
// This is used to prevent other XLA instances from trying to autotune on a
|
||||
// device while another thread is using it.
|
||||
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec);
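// Minimal usage sketch (the autotuning body is hypothetical, not part of this
// change):
//
//   {
//     tensorflow::mutex_lock lock = LockGpu(stream_exec);
//     // ... autotune on the device behind stream_exec ...
//   }  // Lock released; other threads may now autotune on this device.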
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
|
@ -485,9 +485,9 @@ TEST_F(GpuKernelTilingTest,
|
||||
}
|
||||
|
||||
ENTRY kernel_entry {
|
||||
arg0 = f32[8,64,4]{2,1,0} parameter(0)
|
||||
arg0 = f32[8,64,32]{2,1,0} parameter(0)
|
||||
constant0 = f32[] constant(0)
|
||||
ROOT reduce0 = f32[8,4]{0,1} reduce(arg0, constant0), dimensions={1},
|
||||
ROOT reduce0 = f32[8,32]{0,1} reduce(arg0, constant0), dimensions={1},
|
||||
to_apply=reduction0
|
||||
})";
|
||||
|
||||
@ -506,6 +506,37 @@ TEST_F(GpuKernelTilingTest,
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
|
||||
}
|
||||
|
||||
TEST_F(GpuKernelTilingTest, RowReductionWithSmallDimensionNotTiled) {
|
||||
const char *const kHloString = R"(
|
||||
HloModule reduction
|
||||
reduction0 {
|
||||
x0 = f32[] parameter(0)
|
||||
y0 = f32[] parameter(1)
|
||||
ROOT add0 = f32[] add(x0, y0)
|
||||
}
|
||||
|
||||
ENTRY kernel_entry {
|
||||
arg0 = f32[8,6,16]{2,1,0} parameter(0)
|
||||
constant0 = f32[] constant(0)
|
||||
ROOT reduce0 = f32[8,6]{1,0} reduce(arg0, constant0), dimensions={2},
|
||||
to_apply=reduction0
|
||||
})";
|
||||
|
||||
// Check that the kernel is not tiled by looking for llvm.nvvm.shfl.sync.down.
|
||||
auto hlo_module =
|
||||
ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
|
||||
CompileAndVerifyIr(std::move(hlo_module),
|
||||
R"(
|
||||
; CHECK-LABEL: define void @reduce
|
||||
; CHECK-NOT: call float @llvm.nvvm.shfl.sync.down.f32
|
||||
; CHECK: }
|
||||
)",
|
||||
/*match_optimized_ir=*/true);
|
||||
|
||||
// Check that the kernel runs correctly.
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -111,8 +111,8 @@ TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) {
|
||||
hlo_module->AddEmbeddedComputation(embedded_builder.Build());
|
||||
}
|
||||
|
||||
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
auto reduce_shape = ShapeUtil::MakeShape(F32, {2});
|
||||
auto param_shape = ShapeUtil::MakeShape(F32, {32, 32});
|
||||
auto reduce_shape = ShapeUtil::MakeShape(F32, {32});
|
||||
HloInstruction* param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, param_shape, "x"));
|
||||
HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
|
||||
|
@ -85,10 +85,6 @@ class Thunk {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns true if this kernel will autotune for the stream device the next
|
||||
// time it is run.
|
||||
virtual bool WillAutotuneKernel(se::Stream* /*stream*/) { return false; }
|
||||
|
||||
// Execute the kernel for the thunk on the given stream. This method must be
|
||||
// called after Initialize and can be called multiple times over Thunk's
|
||||
// lifetime. 'stream' and 'profiler' must be non-null.
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
@ -384,27 +385,11 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
|
||||
|
||||
StatusOr<HloInstruction*> ElideDegenerateDims(
|
||||
HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
|
||||
CHECK(absl::c_is_sorted(dims_to_elide));
|
||||
|
||||
const Shape& input_shape = operand->shape();
|
||||
// First accumulate in reverse
|
||||
std::vector<int64> new_shape_dim_bounds;
|
||||
new_shape_dim_bounds.reserve(input_shape.dimensions_size() -
|
||||
dims_to_elide.size());
|
||||
int64 dims_to_elide_idx = dims_to_elide.size() - 1;
|
||||
for (int64 i = input_shape.dimensions_size() - 1; i >= 0; i--) {
|
||||
if (dims_to_elide_idx >= 0 && i == dims_to_elide[dims_to_elide_idx]) {
|
||||
CHECK_EQ(input_shape.dimensions(i), 1);
|
||||
dims_to_elide_idx--;
|
||||
} else {
|
||||
new_shape_dim_bounds.push_back(input_shape.dimensions(i));
|
||||
}
|
||||
}
|
||||
|
||||
absl::c_reverse(new_shape_dim_bounds);
|
||||
Shape output_shape =
|
||||
ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds);
|
||||
return MakeReshapeHlo(output_shape, operand);
|
||||
return MakeReshapeHlo(
|
||||
ShapeUtil::FilterDimensions(
|
||||
[&](int64 dim) { return !absl::c_linear_search(dims_to_elide, dim); },
|
||||
operand->shape()),
|
||||
operand);
|
||||
}
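// Illustrative example (the operand shape is hypothetical, not part of this
// change): for an operand of shape f32[1,4,1,8] and dims_to_elide = {0, 2},
// FilterDimensions keeps dimensions 1 and 3, so the reshape target shape is
// f32[4,8].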
|
||||
|
||||
StatusOr<HloInstruction*> InsertDegenerateDims(
|
||||
|
@ -957,7 +957,6 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
|
||||
//
|
||||
// Returns true if:
|
||||
//
|
||||
// * fusion is a loop or input fusion, AND
|
||||
// * fusion_param is used by the root of dynamic-update-slice as the "base" of
|
||||
// the update, i.e. the thing being updated, AND
|
||||
// * all other uses of fusion_param are dynamic-slices that slice the same
|
||||
@ -977,13 +976,6 @@ static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
|
||||
CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
|
||||
CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
|
||||
|
||||
// fusion must be a loop or input fusion.
|
||||
auto kind = fusion->fusion_kind();
|
||||
if (kind != HloInstruction::FusionKind::kLoop &&
|
||||
kind != HloInstruction::FusionKind::kInput) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// fusion_param must be used by the root as the "base" of the
|
||||
// dynamic-update-slice. The natural way to check this would be
|
||||
//
|
||||
|
@ -423,10 +423,12 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
|
||||
#ifndef NDEBUG
|
||||
const Literal* input_literal = arg_literals_[parameter->parameter_number()];
|
||||
VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
|
||||
DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()))
|
||||
<< "parameter shape is: " << ShapeUtil::HumanString(parameter->shape())
|
||||
DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
|
||||
input_literal->shape()))
|
||||
<< "parameter shape is: "
|
||||
<< ShapeUtil::HumanStringWithLayout(parameter->shape())
|
||||
<< ", but input literal shape is: "
|
||||
<< ShapeUtil::HumanString(input_literal->shape());
|
||||
<< ShapeUtil::HumanStringWithLayout(input_literal->shape());
|
||||
#endif
|
||||
|
||||
return Status::OK();
|
||||
|
@ -2751,7 +2751,12 @@ template <typename Visitor>
|
||||
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
|
||||
const InternalCompareFunction* operand_order,
|
||||
bool ignore_control_predecessors) {
|
||||
visitor->ReserveVisitStates(root->GetModule()->instruction_count());
|
||||
// Calculating the instruction count within a module can be expensive on large
// models, so only do it if the visit state is empty. This will help when the
|
||||
// same visitor is reused across many computations of a single module.
|
||||
if (visitor->VisitStateSize() == 0) {
|
||||
visitor->ReserveVisitStates(root->GetModule()->instruction_count());
|
||||
}
|
||||
|
||||
// dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
|
||||
//
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||
@ -122,8 +123,9 @@ Status HloModuleGroupMetadata::Build() {
|
||||
// Visit the computations in postorder so that the companion information grows
|
||||
// from inner computations to outer ones.
|
||||
for (HloModule* module : modules_) {
|
||||
FunctionVisitor function_visitor(visitor);
|
||||
for (HloComputation* computation : module->MakeComputationPostOrder()) {
|
||||
TF_RETURN_IF_ERROR(computation->Accept(visitor));
|
||||
TF_RETURN_IF_ERROR(computation->Accept(&function_visitor));
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(VerifyCompanionSets());
|
||||
@ -370,8 +372,9 @@ Status HloModuleGroupMetadata::RecordInstructions() {
|
||||
};
|
||||
|
||||
for (HloModule* module : modules_) {
|
||||
FunctionVisitor function_visitor(visitor);
|
||||
for (auto* computation : module->computations()) {
|
||||
TF_RETURN_IF_ERROR(computation->Accept(visitor));
|
||||
TF_RETURN_IF_ERROR(computation->Accept(&function_visitor));
|
||||
}
|
||||
}
|
||||
VLOG(2) << "Created " << channels_.size() << " channels";
|
||||
|
@ -2671,7 +2671,7 @@ bool HloParser::ParseAttributeHelper(
|
||||
if (!ParseAttributeName(&name)) {
|
||||
return Error(loc, "error parsing attributes");
|
||||
}
|
||||
VLOG(1) << "Parsing attribute " << name;
|
||||
VLOG(3) << "Parsing attribute " << name;
|
||||
if (!seen_attrs->insert(name).second) {
|
||||
return Error(loc, StrFormat("attribute %s already exists", name));
|
||||
}
|
||||
@ -2943,7 +2943,7 @@ bool HloParser::ParseAttributeAsProtoMessageHelper(
|
||||
if (!ParseAttributeName(&name)) {
|
||||
return Error(loc, "error parsing attributes");
|
||||
}
|
||||
VLOG(1) << "Parsing attribute " << name;
|
||||
VLOG(3) << "Parsing attribute " << name;
|
||||
if (!seen_attrs->insert(name).second) {
|
||||
return Error(loc, StrFormat("attribute %s already exists", name));
|
||||
}
|
||||
@ -3650,7 +3650,7 @@ bool HloParser::CanBeShape() {
|
||||
}
|
||||
|
||||
bool HloParser::ParseName(string* result) {
|
||||
VLOG(1) << "ParseName";
|
||||
VLOG(3) << "ParseName";
|
||||
if (lexer_.GetKind() != TokKind::kIdent &&
|
||||
lexer_.GetKind() != TokKind::kName) {
|
||||
return TokenError("expects name");
|
||||
@ -3670,7 +3670,7 @@ bool HloParser::ParseAttributeName(string* result) {
|
||||
}
|
||||
|
||||
bool HloParser::ParseString(string* result) {
|
||||
VLOG(1) << "ParseString";
|
||||
VLOG(3) << "ParseString";
|
||||
if (lexer_.GetKind() != TokKind::kString) {
|
||||
return TokenError("expects string");
|
||||
}
|
||||
@ -3784,7 +3784,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) {
|
||||
}
|
||||
|
||||
bool HloParser::ParseOpcode(HloOpcode* result) {
|
||||
VLOG(1) << "ParseOpcode";
|
||||
VLOG(3) << "ParseOpcode";
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
return TokenError("expects opcode");
|
||||
}
|
||||
@ -3800,7 +3800,7 @@ bool HloParser::ParseOpcode(HloOpcode* result) {
|
||||
}
|
||||
|
||||
bool HloParser::ParseFftType(FftType* result) {
|
||||
VLOG(1) << "ParseFftType";
|
||||
VLOG(3) << "ParseFftType";
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
return TokenError("expects fft type");
|
||||
}
|
||||
@ -3829,7 +3829,7 @@ bool HloParser::ParseComparisonDirection(ComparisonDirection* result) {
|
||||
}
|
||||
|
||||
bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
|
||||
VLOG(1) << "ParseFusionKind";
|
||||
VLOG(3) << "ParseFusionKind";
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
return TokenError("expects fusion kind");
|
||||
}
|
||||
@ -3846,7 +3846,7 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
|
||||
}
|
||||
|
||||
bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
|
||||
VLOG(1) << "ParseRandomDistribution";
|
||||
VLOG(3) << "ParseRandomDistribution";
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
return TokenError("expects random distribution");
|
||||
}
|
||||
@ -3863,7 +3863,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
|
||||
}
|
||||
|
||||
bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
|
||||
VLOG(1) << "ParsePrecision";
|
||||
VLOG(3) << "ParsePrecision";
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
return TokenError("expects random distribution");
|
||||
}
|
||||
@ -3880,7 +3880,7 @@ bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
|
||||
}
|
||||
|
||||
bool HloParser::ParseInt64(int64* result) {
|
||||
VLOG(1) << "ParseInt64";
|
||||
VLOG(3) << "ParseInt64";
|
||||
if (lexer_.GetKind() != TokKind::kInt) {
|
||||
return TokenError("expects integer");
|
||||
}
|
||||
@ -3969,7 +3969,7 @@ bool HloParser::ParseBool(bool* result) {
|
||||
}
|
||||
|
||||
bool HloParser::ParseToken(TokKind kind, const string& msg) {
|
||||
VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg;
|
||||
VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
|
||||
if (lexer_.GetKind() != kind) {
|
||||
return TokenError(msg);
|
||||
}
|
||||
|
@ -41,6 +41,8 @@ class HloPassInterface {
|
||||
// module group. Ideally, the module group variant would be named "Run" as
|
||||
// well, but C++ does not handle overloaded virtual methods well.
|
||||
virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
|
||||
|
||||
virtual bool IsPassPipeline() { return false; }
|
||||
};
|
||||
|
||||
// Base class for passes which are module-scoped.
|
||||
|
@ -58,14 +58,21 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal(
|
||||
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
|
||||
bool changed = false;
|
||||
for (HloPassInterface* pass : passes) {
|
||||
VLOG(1) << " HLO pass " << pass->name();
|
||||
absl::string_view pass_name = pass->name();
|
||||
VLOG(1) << " HLO pass " << pass_name;
|
||||
MaybeDumpHlo(*hlo,
|
||||
/*after_pass_name=*/last_pass_name,
|
||||
/*before_pass_name=*/pass->name());
|
||||
/*before_pass_name=*/pass_name);
|
||||
if (!pass->IsPassPipeline()) {
|
||||
compilation_stats_->StartPass(pass_name);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
|
||||
changed |= pass_changed;
|
||||
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
|
||||
last_pass_name = string(pass->name());
|
||||
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
|
||||
last_pass_name = string(pass_name);
|
||||
if (!pass->IsPassPipeline()) {
|
||||
compilation_stats_->EndPass(pass_name);
|
||||
}
|
||||
}
|
||||
MaybeDumpHlo(*hlo,
|
||||
/*after_pass_name=*/last_pass_name,
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/service/compilation_stats.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -34,7 +35,14 @@ namespace xla {
|
||||
// Pipeline of HLO passes.
|
||||
class HloPassPipeline : public HloPassInterface {
|
||||
public:
|
||||
explicit HloPassPipeline(const string& name) : name_(name) {}
|
||||
explicit HloPassPipeline(const string& name,
|
||||
CompilationStats* compilation_stats = nullptr)
|
||||
: name_(name), compilation_stats_(compilation_stats) {
|
||||
if (compilation_stats == nullptr) {
|
||||
empty_compilation_stats_ = CompilationStats::MakeNoopStats();
|
||||
compilation_stats_ = empty_compilation_stats_.get();
|
||||
}
|
||||
}
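// Minimal usage sketch (the pipeline name and pass are hypothetical, not part
// of this change):
//
//   HloPassPipeline pipeline("example-pipeline", /*compilation_stats=*/nullptr);
//   pipeline.AddPass<HloDCE>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));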
|
||||
absl::string_view name() const override { return name_; }
|
||||
|
||||
// Add a pass to the pipeline. It should be called with the arguments for the
|
||||
@ -65,6 +73,8 @@ class HloPassPipeline : public HloPassInterface {
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
|
||||
|
||||
bool IsPassPipeline() override { return true; }
|
||||
|
||||
private:
|
||||
// Returns the set of passes which are enabled. DebugOptions can selectively
|
||||
// disable passes via --xla_disable_hlo_passes flag.
|
||||
@ -105,6 +115,11 @@ class HloPassPipeline : public HloPassInterface {
|
||||
std::vector<std::unique_ptr<HloPassInterface>> passes_;
|
||||
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
|
||||
bool run_called_ = false;
|
||||
|
||||
CompilationStats* compilation_stats_;
|
||||
// Default stats instance for when one is not passed in the constructor.
|
||||
// Use via compilation_stats_, not directly.
|
||||
std::unique_ptr<CompilationStats> empty_compilation_stats_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -35,7 +35,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
|
||||
const ShapeIndex& index) {
|
||||
BufferAllocation::Slice buffer_slice;
|
||||
if (hlo.opcode() == HloOpcode::kParameter &&
|
||||
hlo.parent() == hlo.parent()->parent()->entry_computation()) {
|
||||
hlo.parent() == module_.entry_computation()) {
|
||||
// Entry computation parameters may alias with each other but may not alias
|
||||
// with our temporary buffers.
|
||||
buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0);
|
||||
@ -78,12 +78,9 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
|
||||
.xla_llvm_enable_invariant_load_metadata()) {
|
||||
// Parameters of the entry computation are never stored to, loading from a
|
||||
// parameter pointer should always return the same result within a loop.
|
||||
if (hlo.opcode() == HloOpcode::kParameter) {
|
||||
const std::vector<HloInstruction*>& parameter_instructions =
|
||||
module_.entry_computation()->parameter_instructions();
|
||||
if (absl::c_linear_search(parameter_instructions, &hlo)) {
|
||||
array->MarkInvariantOverWholeProgram(context_);
|
||||
}
|
||||
if (hlo.opcode() == HloOpcode::kParameter &&
|
||||
hlo.parent() == module_.entry_computation()) {
|
||||
array->MarkInvariantOverWholeProgram(context_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2544,7 +2544,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
};
|
||||
|
||||
// Check the shapes of computation parameters and return types.
|
||||
if (!ShapeUtil::Equal(condition.result(), ShapeUtil::MakeShape(PRED, {}))) {
|
||||
if (!ShapeUtil::Compatible(condition.result(),
|
||||
ShapeUtil::MakeShape(PRED, {}))) {
|
||||
return InvalidArgument("Condition must return a boolean; got %s.",
|
||||
shape_string());
|
||||
}
|
||||
@ -2564,8 +2565,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
const Shape& branch_index,
|
||||
absl::Span<const ProgramShape> branch_computations,
|
||||
absl::Span<const Shape> branch_operands) {
|
||||
if (!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
|
||||
!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(S32, {}))) {
|
||||
if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
|
||||
!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
|
||||
return InvalidArgument("branch_index must be bool or int32; got %s.",
|
||||
ShapeUtil::HumanString(branch_index));
|
||||
}
|
||||
|
@ -285,7 +285,7 @@ Status TransferManager::TransferBufferFromDevice(
|
||||
void* destination) {
|
||||
if (source.size() < size) {
|
||||
return FailedPrecondition(
|
||||
"Source allocation on device not large enough for data tranfer: "
|
||||
"Source allocation on device not large enough for data transfer: "
|
||||
"%d < %d",
|
||||
source.size(), size);
|
||||
}
|
||||
@ -298,7 +298,7 @@ Status TransferManager::TransferBufferToDevice(
|
||||
se::DeviceMemoryBase* destination) {
|
||||
if (destination->size() < size) {
|
||||
return FailedPrecondition(
|
||||
"Destination allocation on device not large enough for data tranfer: "
|
||||
"Destination allocation on device not large enough for data transfer: "
|
||||
"%d < %d",
|
||||
destination->size(), size);
|
||||
}
|
||||
@ -336,4 +336,9 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
|
||||
return std::move(shaped_buffer);
|
||||
}
|
||||
|
||||
StatusOr<Shape> TransferManager::ChooseCompactLayoutForShape(
|
||||
const Shape& host_shape) const {
|
||||
return LayoutUtil::GetWithDefaultLayout(host_shape);
|
||||
}
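// Minimal sketch of a backend override (TiledTransferManager and its helper
// are hypothetical, not part of this change):
//
//   StatusOr<Shape> TiledTransferManager::ChooseCompactLayoutForShape(
//       const Shape& host_shape) const {
//     // Pick a minor-to-major order that minimizes padding on tiled memory.
//     return PickLayoutWithLeastPadding(host_shape);
//   }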
|
||||
|
||||
} // namespace xla
|
||||
|
@ -216,6 +216,15 @@ class TransferManager {
|
||||
// region for a host-to-device transfer.
|
||||
virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;
|
||||
|
||||
// Chooses a compact layout for 'shape', ignoring any existing layout on
// 'shape'. What "compact" means is left up to the backend. The
|
||||
// intended use case is to choose a layout that avoids excessive padding on
|
||||
// devices that have tiled memory architectures.
|
||||
// The default implementation always picks a default (major-to-minor) layout.
|
||||
// Fails if 'shape' cannot be represented by the device.
|
||||
virtual StatusOr<Shape> ChooseCompactLayoutForShape(
|
||||
const Shape& host_shape) const;
|
||||
|
||||
// Allocates a ScopedShapedBuffer which can hold data with the given on-host
|
||||
// shape. The on-device shape may be different as indicated by
|
||||
// HostShapeToDeviceShape.
|
||||
|
@ -200,6 +200,12 @@ class Shape {
|
||||
bool operator==(const Shape& other) const { return Equal()(*this, other); }
|
||||
bool operator!=(const Shape& other) const { return !(*this == other); }
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const Shape& s) {
|
||||
return H::combine(std::move(h), s.element_type_, s.dimensions_,
|
||||
s.dynamic_dimensions_, s.tuple_shapes_, s.layout_);
|
||||
}
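// Illustrative consequence of the hash support above (the map is
// hypothetical, not part of this change): Shape can now key absl hash
// containers directly, e.g.
//
//   absl::flat_hash_map<Shape, int64> shape_counts;
//   shape_counts[ShapeUtil::MakeShape(F32, {32, 32})]++;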
|
||||
|
||||
private:
|
||||
// The element type of this shape (tuple, array, etc).
|
||||
PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "absl/hash/hash_testing.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
@ -210,5 +212,11 @@ TEST_F(ShapeTest, ProgramShapeToString) {
|
||||
prog.ToString());
|
||||
}
|
||||
|
||||
TEST_F(ShapeTest, SupportsAbslHash) {
|
||||
EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(
|
||||
{opaque_, token_, scalar_, scalar_with_tile_, matrix_, matrix2_, tuple_,
|
||||
nested_tuple_, dyanmic_matrix_}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <new>
|
||||
@ -42,7 +43,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
@ -604,10 +604,9 @@ std::unique_ptr<HloComputation> MakeReduceTestComputation() {
|
||||
|
||||
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {32}), 0));
|
||||
auto const1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
|
||||
auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
|
||||
@ -618,7 +617,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
|
||||
HloInstruction::FusionKind::kInput);
|
||||
|
||||
EXPECT_TRUE(
|
||||
LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(15),
|
||||
LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(496),
|
||||
ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||
}
|
||||
|
||||
@ -896,8 +895,7 @@ void BM_ParallelFusion(int num_iters) {
|
||||
// Initialize thread pool.
|
||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
||||
intra_op_parallelism_threads);
|
||||
tensorflow::EigenThreadPoolWrapper tp(&pool);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
|
||||
|
||||
// Initialize ExecutableRunOptions.
|
||||
ExecutableRunOptions options;
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/byte_order.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
@ -108,12 +107,10 @@ struct LocalClientTestBase::EigenThreadPoolWrapper {
|
||||
explicit EigenThreadPoolWrapper()
|
||||
: pool(new tensorflow::thread::ThreadPool(
|
||||
tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
|
||||
wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
|
||||
device(new Eigen::ThreadPoolDevice(wrapper.get(),
|
||||
wrapper->NumThreads())) {}
|
||||
device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
|
||||
pool->NumThreads())) {}
|
||||
|
||||
std::unique_ptr<tensorflow::thread::ThreadPool> pool;
|
||||
std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
|
||||
std::unique_ptr<Eigen::ThreadPoolDevice> device;
|
||||
};
|
||||
|
||||
|
@ -295,255 +295,191 @@ XLA_TEST_F(MultiOutputFusionTest,
|
||||
DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) {
|
||||
const string testcase = absl::StrCat(kScalarOps, R"(
|
||||
fused_reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
c0 = f32[] constant(0)
|
||||
r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
|
||||
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
|
||||
r1 = f32[32,32]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
|
||||
mul = f32[32,32,32]{2,1,0} multiply(p0, p0)
|
||||
c1 = f32[] constant(5)
|
||||
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
|
||||
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
|
||||
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
|
||||
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p = f32[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
|
||||
calls=fused_reduce
|
||||
p = f32[32,32,32]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(p), kind=kInput,
|
||||
calls=fused_reduce
|
||||
})");
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
|
||||
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
|
||||
result));
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
|
||||
}
|
||||
|
||||
XLA_TEST_F(MultiOutputFusionTest,
|
||||
DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) {
|
||||
const string testcase = absl::StrCat(kScalarOps, R"(
|
||||
fused_reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0 = f32[32,32,32]{2,1,0} parameter(0)
|
||||
c0 = f32[] constant(0)
|
||||
r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
|
||||
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
|
||||
r1 = f32[32,32]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
|
||||
mul = f32[32,32,32]{2,1,0} multiply(p0, p0)
|
||||
c1 = f32[] constant(5)
|
||||
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
|
||||
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
|
||||
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
|
||||
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p = f32[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
|
||||
calls=fused_reduce
|
||||
p = f32[32,32,32]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(p), kind=kInput,
|
||||
calls=fused_reduce
|
||||
})");
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
|
||||
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
|
||||
result));
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
|
||||
}
|
||||
|
||||
XLA_TEST_F(MultiOutputFusionTest,
|
||||
DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) {
|
||||
const string testcase = absl::StrCat(kScalarOps, R"(
|
||||
fused_reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0 = f32[2,32,32]{2,1,0} parameter(0)
|
||||
c0 = f32[] constant(0)
|
||||
r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
|
||||
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
|
||||
r1 = f32[32]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
|
||||
mul = f32[2,32,32]{2,1,0} multiply(p0, p0)
|
||||
c1 = f32[] constant(1.17549e-38)
|
||||
r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
|
||||
r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add
|
||||
ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3)
|
||||
r2 = f32[32]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
|
||||
r3 = f32[32]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add
|
||||
ROOT tuple = (f32[32]{0}, f32[32]{0}, f32[32]{0}) tuple(r1, r2, r3)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p = f32[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
|
||||
calls=fused_reduce
|
||||
p = f32[2,32,32]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[32]{0}, f32[32]{0}, f32[32]{0}) fusion(p), kind=kInput,
|
||||
calls=fused_reduce
|
||||
})");
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
|
||||
LiteralUtil::CreateR1<float>({36, 64}),
|
||||
LiteralUtil::CreateR1<float>({66, 138})),
|
||||
result));
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
|
||||
}
|
||||
|
||||
XLA_TEST_F(MultiOutputFusionTest,
|
||||
DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) {
|
||||
const string testcase = absl::StrCat(kScalarOps, R"(
|
||||
fused_reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0 = f32[2,32,32]{2,1,0} parameter(0)
|
||||
c0 = f32[] constant(0)
|
||||
r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
|
||||
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
|
||||
r1 = f32[2,32]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
|
||||
mul = f32[2,32,32]{2,1,0} multiply(p0, p0)
|
||||
c1 = f32[] constant(5)
|
||||
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
|
||||
ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0})
|
||||
r2 = f32[2,32]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
|
||||
ROOT tuple = (f32[2,32,32]{2,1,0}, f32[2,32]{1,0}, f32[2,32]{1,0})
|
||||
tuple(p0, r1, r2)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p = f32[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
|
||||
kind=kInput, calls=fused_reduce
|
||||
p = f32[2,32,32]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[2,32,32]{2,1,0}, f32[2,32]{1,0}, f32[2,32]{1,0})
|
||||
fusion(p), kind=kInput, calls=fused_reduce
|
||||
})");
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
|
||||
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
|
||||
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
|
||||
result));
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
|
||||
}
|
||||
|
||||
XLA_TEST_F(MultiOutputFusionTest,
|
||||
DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) {
|
||||
const string testcase = absl::StrCat(kScalarOps, R"(
|
||||
fused_reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0 = f32[32,32,2]{2,1,0} parameter(0)
|
||||
c0 = f32[] constant(0)
|
||||
r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
|
||||
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
|
||||
r1 = f32[32,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
|
||||
mul = f32[32,32,2]{2,1,0} multiply(p0, p0)
|
||||
c1 = f32[] constant(5)
|
||||
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
|
||||
ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0})
|
||||
r2 = f32[32,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
|
||||
ROOT tuple = (f32[32,2]{1,0}, f32[32,32,2]{2,1,0}, f32[32,2]{1,0})
|
||||
tuple(r1, mul, r2)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p = f32[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
|
||||
kind=kInput, calls=fused_reduce
|
||||
p = f32[32,32,2]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[32,2]{1,0}, f32[32,32,2]{2,1,0}, f32[32,2]{1,0})
|
||||
fusion(p), kind=kInput, calls=fused_reduce
|
||||
})");
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
|
||||
LiteralUtil::CreateR3<float>(
|
||||
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
|
||||
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
|
||||
result));
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
|
||||
}
|
||||
|
||||
XLA_TEST_F(MultiOutputFusionTest,
|
||||
DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) {
|
||||
const string testcase = absl::StrCat(kScalarOps, R"(
|
||||
const string testcase = R"(
|
||||
HloModule m, is_scheduled=true
|
||||
|
||||
Add {
|
||||
lhsadd = f32[] parameter(0)
|
||||
rhsadd = f32[] parameter(1)
|
||||
ROOT add = f32[] add(lhsadd, rhsadd)
|
||||
}
|
||||
fused_reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0 = f32[2,32,32]{2,1,0} parameter(0)
|
||||
c0 = f32[] constant(0)
|
||||
r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
|
||||
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
|
||||
r1 = f32[32]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
|
||||
mul = f32[2,32,32]{2,1,0} multiply(p0, p0)
|
||||
c1 = f32[] constant(5)
|
||||
b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={}
|
||||
mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1)
|
||||
ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0})
|
||||
tuple(r1, mul, mul2)
|
||||
b1 = f32[2,32,32]{2,1,0} broadcast(c1), dimensions={}
|
||||
mul2 = f32[2,32,32]{2,1,0} multiply(p0, b1)
|
||||
ROOT tuple = (f32[32]{0}, f32[2,32,32]{2,1,0}, f32[2,32,32]{2,1,0})
|
||||
tuple(r1, mul, mul2)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p = f32[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
|
||||
kind=kInput, calls=fused_reduce
|
||||
})");
|
||||
p = f32[2,32,32]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[32]{0}, f32[2,32,32]{2,1,0}, f32[2,32,32]{2,1,0})
|
||||
fusion(p), kind=kInput, calls=fused_reduce
|
||||
})";
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::CreateR1<float>({14, 22}),
|
||||
LiteralUtil::CreateR3<float>(
|
||||
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
|
||||
LiteralUtil::CreateR3<float>(
|
||||
{{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
|
||||
result));
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
|
||||
}
|
||||
|
||||
XLA_TEST_F(MultiOutputFusionTest,
|
||||
DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) {
|
||||
const string testcase = absl::StrCat(kScalarOps, R"(
|
||||
fused_reduce {
|
||||
p0 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p0 = f32[2,32,32]{2,1,0} parameter(0)
|
||||
init1 = f32[] parameter(1)
|
||||
init2 = f32[] parameter(2)
|
||||
r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add
|
||||
r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max
|
||||
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
|
||||
r1 = f32[2,32]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add
|
||||
r2 = f32[2,32]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max
|
||||
ROOT tuple = (f32[2,32]{1,0}, f32[2,32]{1,0}) tuple(r1, r2)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p = f32[2,2,2]{2,1,0} parameter(0)
|
||||
p = f32[2,32,32]{2,1,0} parameter(0)
|
||||
i = f32[] parameter(1)
|
||||
j = f32[] parameter(2)
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
|
||||
calls=fused_reduce
|
||||
ROOT fusion = (f32[2,32]{1,0}, f32[2,32]{1,0}) fusion(p, i, j),
|
||||
kind=kInput, calls=fused_reduce
|
||||
})");
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
auto init1 = LiteralUtil::CreateR0<float>(5);
|
||||
auto init2 = LiteralUtil::CreateR0<float>(6);
|
||||
Literal result =
|
||||
ExecuteNoHloPasses(std::move(module), {¶m, &init1, &init2});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
|
||||
LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
|
||||
result));
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
|
||||
}
|
||||
|
||||
XLA_TEST_F(MultiOutputFusionTest,
|
||||
DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) {
|
||||
const string testcase = absl::StrCat(kScalarOps, R"(
|
||||
fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) {
|
||||
p0 = f16[2,2,2]{2,1,0} parameter(0)
|
||||
convert = f32[2,2,2]{2,1,0} convert(p0)
|
||||
fused_reduce (p0: f16[2,32,32]) -> (f32[2,32], f32[2,32], f16[2,32,32]) {
|
||||
p0 = f16[2,32,32]{2,1,0} parameter(0)
|
||||
convert = f32[2,32,32]{2,1,0} convert(p0)
|
||||
c0 = f32[] constant(0)
|
||||
r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add
|
||||
mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
|
||||
r1 = f32[2,32]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add
|
||||
mul = f32[2,32,32]{2,1,0} multiply(convert, convert)
|
||||
c1 = f32[] constant(5)
|
||||
r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
|
||||
ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0})
|
||||
r2 = f32[2,32]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
|
||||
ROOT tuple = (f32[2,32]{1,0}, f32[2,32]{1,0}, f16[2,32,32]{2,1,0})
|
||||
tuple(r1, r2, p0)
|
||||
}
|
||||
|
||||
ENTRY reduce {
|
||||
p = f16[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
|
||||
p = f16[2,32,32]{2,1,0} parameter(0)
|
||||
ROOT fusion = (f32[2,32]{1,0}, f32[2,32]{1,0}, f16[2,32,32]{2,1,0}) fusion(p),
|
||||
kind=kInput, calls=fused_reduce
|
||||
})");
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param = LiteralUtil::CreateR3<Eigen::half>(
|
||||
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
|
||||
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||
LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
|
||||
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
|
||||
LiteralUtil::CreateR3<Eigen::half>(
|
||||
{{{Eigen::half(1), Eigen::half(2)},
|
||||
{Eigen::half(3), Eigen::half(4)}},
|
||||
{{Eigen::half(5), Eigen::half(6)},
|
||||
{Eigen::half(7), Eigen::half(8)}}})),
|
||||
result));
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -6,31 +6,18 @@ package(
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"py_test",
|
||||
"tf_custom_op_library",
|
||||
"tf_gen_op_libs",
|
||||
"tf_gen_op_wrapper_py",
|
||||
"tf_kernel_library",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
|
||||
py_library(
|
||||
name = "batch_py",
|
||||
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:batch_ops",
|
||||
"//tensorflow/python:batch_ops_gen",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,14 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import gen_batch_ops
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_batch_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.ops.batch_ops import batch
|
||||
from tensorflow.python.ops.batch_ops import batch_function
|
||||
from tensorflow.python.ops.batch_ops import unbatch
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
||||
@ops.RegisterGradient("Batch")
|
||||
@ -55,85 +54,6 @@ def _UnbatchGrad(op, grad): # pylint: disable=invalid-name
|
||||
]
|
||||
|
||||
|
||||
def batch_function(num_batch_threads,
|
||||
max_batch_size,
|
||||
batch_timeout_micros,
|
||||
allowed_batch_sizes=None,
|
||||
max_enqueued_batches=10):
|
||||
"""Batches the computation done by the decorated function.
|
||||
|
||||
So, for example, in the following code
|
||||
|
||||
```python
|
||||
@batch_function(1, 2, 3)
|
||||
def layer(a):
|
||||
return tf.matmul(a, a)
|
||||
|
||||
b = layer(w)
|
||||
```
|
||||
|
||||
if more than one session.run call is simultaneously trying to compute `b`,
|
||||
the values of `w` will be gathered, non-deterministically concatenated
|
||||
along the first axis, and only one thread will run the computation. See the
|
||||
documentation of the `Batch` op for more details.
|
||||
|
||||
Assumes that all arguments of the decorated function are Tensors which will
|
||||
be batched along their first dimension.
|
||||
|
||||
SparseTensor is not supported. The return value of the decorated function
|
||||
must be a Tensor or a list/tuple of Tensors.
|
||||
|
||||
Args:
|
||||
num_batch_threads: Number of scheduling threads for processing batches
|
||||
of work. Determines the number of batches processed in parallel.
|
||||
max_batch_size: Batch sizes will never be bigger than this.
|
||||
batch_timeout_micros: Maximum number of microseconds to wait before
|
||||
outputting an incomplete batch.
|
||||
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
|
||||
does nothing. Otherwise, supplies a list of batch sizes, causing the op
|
||||
to pad batches up to one of those sizes. The entries must increase
|
||||
monotonically, and the final entry must equal max_batch_size.
|
||||
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
The decorated function will return the unbatched computation output Tensors.
|
||||
"""
|
||||
|
||||
def decorator(fn): # pylint: disable=missing-docstring
|
||||
|
||||
def decorated(*args): # pylint: disable=missing-docstring
|
||||
|
||||
@function.defun(autograph=False)
|
||||
def computation(*computation_args):
|
||||
return fn(*computation_args)
|
||||
|
||||
computation = computation.get_concrete_function(
|
||||
*[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
|
||||
for i, x in enumerate(args)])
|
||||
|
||||
with ops.name_scope("batch") as name:
|
||||
for a in args:
|
||||
if not isinstance(a, ops.Tensor):
|
||||
raise ValueError("All arguments to functions decorated with "
|
||||
"`batch_function` are supposed to be Tensors; "
|
||||
"found %s" % repr(a))
|
||||
return gen_batch_ops.batch_function(
|
||||
num_batch_threads=num_batch_threads,
|
||||
max_batch_size=max_batch_size,
|
||||
batch_timeout_micros=batch_timeout_micros,
|
||||
allowed_batch_sizes=allowed_batch_sizes,
|
||||
max_enqueued_batches=max_enqueued_batches,
|
||||
shared_name=name,
|
||||
f=computation,
|
||||
in_tensors=list(args),
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.dtype for o in computation.outputs])
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
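For reference, a minimal usage sketch of `batch_function` (shown above, and re-exported from `tensorflow.python.ops.batch_ops` in this change), assuming the TF 1.x graph/session APIs these tests use; the decorator arguments, the `plus_one` name, and the placeholder shape are illustrative choices mirroring `testBasicUnbatchDecorated`, not values mandated by the API:

```python
import tensorflow as tf
from tensorflow.python.ops.batch_ops import batch_function

@batch_function(num_batch_threads=1, max_batch_size=10,
                batch_timeout_micros=100000)  # flush a partial batch after ~100ms
def plus_one(t):
  # Every positional argument must be a Tensor; inputs from concurrent
  # session.run calls are concatenated along dimension 0 before this body runs.
  return t + 1

inp = tf.placeholder(dtype=tf.int32, shape=[1])
out = plus_one(inp)  # each caller gets back only its own unbatched slice
with tf.Session() as sess:
  print(sess.run(out, feed_dict={inp: [41]}))  # -> [42]
```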
|
||||
|
||||
def batch_function_v1(num_batch_threads,
|
||||
max_batch_size,
|
||||
batch_timeout_micros,
|
||||
|
@ -23,12 +23,8 @@ import time
|
||||
|
||||
from tensorflow.contrib.batching.python.ops import batch_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework.errors import InvalidArgumentError
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_batch_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -41,153 +37,6 @@ def delayed_plus1(x):
|
||||
class BatchOpsTest(test.TestCase):
|
||||
"""Tests for batch_ops.{un,}batch."""
|
||||
|
||||
def testBasicBatch(self):
|
||||
"""Tests that a single batched tensor executes together and only once."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, index, _ = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=36000000, grad_timeout_micros=0,
|
||||
batching_queue="")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(
|
||||
sess.run([batched, index], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([batched, index], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
|
||||
# At this point either the thread or the main did the batch and the other
|
||||
# should have empty results.
|
||||
if list(thread_results[0][0]):
|
||||
batch_t = thread_results[0][0]
|
||||
index_t = thread_results[1]
|
||||
empty_b = main_results[0][0]
|
||||
empty_m = main_results[1]
|
||||
else:
|
||||
batch_t = main_results[0][0]
|
||||
index_t = main_results[1]
|
||||
empty_b = thread_results[0][0]
|
||||
empty_m = thread_results[1]
|
||||
|
||||
# Check that both the inputs made it out exactly once.
|
||||
self.assertAllEqual(sorted(batch_t), (1, 2))
|
||||
# Check that we get 2 rows in the index tensor.
|
||||
self.assertEqual(len(index_t), 2)
|
||||
# Check that the other ones are empty.
|
||||
self.assertEqual(len(empty_b), 0)
|
||||
self.assertEqual(len(empty_m), 0)
|
||||
|
||||
def testBatchWithPadding(self):
|
||||
"""Test that batching with padding up to an allowed batch size works."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
|
||||
batched, index, _ = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[5, 10],
|
||||
grad_timeout_micros=0, batching_queue="")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(
|
||||
sess.run([batched, index], feed_dict={inp: [1, 3]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([batched, index], feed_dict={inp: [2, 4]})
|
||||
worker_thread.join()
|
||||
|
||||
# At this point either the thread or the main did the batch and the other
|
||||
# should have empty results.
|
||||
if list(thread_results[0][0]):
|
||||
batch_t = thread_results[0][0]
|
||||
else:
|
||||
batch_t = main_results[0][0]
|
||||
|
||||
# Check that the batch tensor incorporates the padding.
|
||||
self.assertEqual(len(batch_t), 5)
|
||||
|
||||
def testMultipleBatch(self):
|
||||
"""Tests that multiple batched tensors execute together."""
|
||||
with self.cached_session() as sess:
|
||||
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, _, _ = batch_ops.batch(
|
||||
[inp0, inp1],
|
||||
num_batch_threads=1,
|
||||
max_batch_size=2,
|
||||
batch_timeout_micros=36000000,
|
||||
grad_timeout_micros=0,
|
||||
batching_queue="")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(
|
||||
sess.run([batched], feed_dict={inp0: [1],
|
||||
inp1: [2]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]})
|
||||
worker_thread.join()
|
||||
|
||||
# At this point either the thread or the main did the batch and the other
|
||||
# should have empty results.
|
||||
if list(thread_results[0][0]):
|
||||
batch_t = thread_results[0]
|
||||
empty_t = main_results[0]
|
||||
else:
|
||||
batch_t = main_results[0]
|
||||
empty_t = thread_results[0]
|
||||
|
||||
# Assert that the tensors were batched together.
|
||||
self.assertAllEqual(sorted(batch_t[0]), [1, 2])
|
||||
self.assertAllEqual(sorted(batch_t[1]), [2, 3])
|
||||
self.assertAllEqual(empty_t[0], [])
|
||||
self.assertAllEqual(empty_t[1], [])
|
||||
|
||||
def testIllegalBatchDifferentDim0Sizes(self):
|
||||
"""Tests illegally feeding tensors with different dim0 sizes."""
|
||||
with self.cached_session() as sess:
|
||||
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
|
||||
batched, index, _ = batch_ops.batch(
|
||||
[inp0, inp1], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="")
|
||||
with self.assertRaises(Exception) as raised:
|
||||
_ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]})
|
||||
self.assertGreater(
|
||||
raised.exception.message.find("must have equal 0th-dimension size"),
|
||||
0)
|
||||
|
||||
def testBasicUnbatch(self):
|
||||
"""Tests that batch and unbatch work together."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, index, id_t = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[3, 10],
|
||||
grad_timeout_micros=0, batching_queue="")
|
||||
computation = batched[0] + 1
|
||||
result = batch_ops.unbatch(computation, index, id_t,
|
||||
timeout_micros=1000000, shared_name="unbatch")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBasicUnbatchV1Decorated(self):
|
||||
"""Tests that the batch_function_v1 decorator works."""
|
||||
with self.cached_session() as sess:
|
||||
@ -210,206 +59,6 @@ class BatchOpsTest(test.TestCase):
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBasicUnbatchDecorated(self):
|
||||
"""Tests that the batch_function decorator works."""
|
||||
with self.cached_session() as sess:
|
||||
# TODO(apassos): Removing this line causes test flakiness! Ideally should
|
||||
# be investigated.
|
||||
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
|
||||
|
||||
@batch_ops.batch_function(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
self.assertTrue(in_t.shape is not None)
|
||||
return in_t + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchDecoratedWithCapturedInput(self):
|
||||
"""Tests that the batch_function decorator works."""
|
||||
with self.cached_session() as sess:
|
||||
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
|
||||
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
|
||||
|
||||
@batch_ops.batch_function(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
return in_t + captured_inp0 - captured_inp1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchFunctionOp(self):
|
||||
"""Tests that the batch_function op works."""
|
||||
with self.cached_session() as sess:
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def computation(in_t):
|
||||
return in_t + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = gen_batch_ops.batch_function(
|
||||
[inp],
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000,
|
||||
Tout=[dtypes.int32],
|
||||
f=computation,
|
||||
captured_tensors=computation.captured_inputs)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchFunctionOpWithCapturedInput(self):
|
||||
"""Tests that batch_function op works with captured input."""
|
||||
with self.cached_session() as sess:
|
||||
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
|
||||
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def computation(inp):
|
||||
return inp + captured_inp0 - captured_inp1
|
||||
|
||||
result = gen_batch_ops.batch_function(
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[3, 10],
|
||||
batching_queue="",
|
||||
f=computation,
|
||||
in_tensors=[inp],
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.type for o in computation.definition.signature.output_arg])
|
||||
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchFunctionOpWithInputError(self):
|
||||
"""Tests that batch_function op works with error in the inputs."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
|
||||
@function.Defun(dtypes.int32, dtypes.int32)
|
||||
def computation(in0, in1):
|
||||
return in0 + in1
|
||||
|
||||
result = gen_batch_ops.batch_function(
|
||||
[inp], # computation actually expects 2 inputs.
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
batching_queue="",
|
||||
f=computation,
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.type for o in computation.definition.signature.output_arg])
|
||||
|
||||
with self.assertRaisesRegexp(InvalidArgumentError,
|
||||
".*2 arguments.*but 1.*"):
|
||||
sess.run([result], feed_dict={inp: [2]})
|
||||
|
||||
def testBasicUnbatchDecoratedWithReshape(self):
|
||||
"""Tests that the batch_function decorator works."""
|
||||
with self.cached_session() as sess:
|
||||
|
||||
@batch_ops.batch_function(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
return array_ops.reshape(in_t, [-1]) + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [[1]]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [[2]]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testUnbatchTimeout(self):
|
||||
"""Tests that the unbatch timeout works."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, index, id_t = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=36000000, grad_timeout_micros=0,
|
||||
batching_queue="")
|
||||
computation = batched[0] + 1
|
||||
timeout_micros = 10
|
||||
result = batch_ops.unbatch(computation, index, id_t, timeout_micros,
|
||||
shared_name="shared_unbatch")
|
||||
# Set up a parallel pipeline that delays the computation, but uses the
|
||||
# same unbatch resource object as the non-delayed pipeline.
|
||||
computation_delayed = script_ops.py_func(delayed_plus1,
|
||||
[batched[0]],
|
||||
dtypes.int32)
|
||||
result_delayed = batch_ops.unbatch(computation_delayed,
|
||||
index,
|
||||
id_t,
|
||||
timeout_micros,
|
||||
shared_name="shared_unbatch")
|
||||
|
||||
thread_results = []
|
||||
def worker():
|
||||
# A first call using the non-delayed pipeline. The batcher will send an
|
||||
# empty tensor along the non-delayed pipeline.
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
time.sleep(0.1) # Ensure the thread's call starts first.
|
||||
# A second call using the delayed pipeline. The batcher will send the
|
||||
# batched tensor along the delayed pipeline, thus delaying the arrival of
|
||||
# the batched tensor at the unbatch op, relative to the empty tensor.
|
||||
#
|
||||
# TODO(olston, apassos): Avoid relying on the order in which the batch op
|
||||
# emits the empty tensor versus the batched one.
|
||||
_ = sess.run([result_delayed], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
# The thread's call should hit the timeout, and thus get 0 results.
|
||||
self.assertEqual(len(thread_results), 0)
|
||||
|
||||
def testUnbatchGrad(self):
|
||||
"""Tests that batch and unbatch are differentiable."""
|
||||
with self.cached_session() as sess:
|
||||
@ -434,6 +83,5 @@ class BatchOpsTest(test.TestCase):
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [4])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -108,8 +108,8 @@ class NumpyState(base.Trackable):
|
||||
except AttributeError:
|
||||
value = _NumpyWrapper(value)
|
||||
self._track_trackable(value, name=name, overwrite=True)
|
||||
elif (name not in ("_setattr_tracking", "_update_uid")
|
||||
and getattr(self, "_setattr_tracking", True)):
|
||||
elif (name not in ("_self_setattr_tracking", "_self_update_uid")
|
||||
and getattr(self, "_self_setattr_tracking", True)):
|
||||
# Mixing restore()-created attributes with user-added trackable
|
||||
# objects is tricky, since we can't use the `_lookup_dependency` trick to
|
||||
# re-create attributes (we might accidentally steal the restoration for
|
||||
@ -154,4 +154,3 @@ class _NumpyWrapper(core_python_state.PythonState):
|
||||
self.array = numpy.load(string_file, allow_pickle=False)
|
||||
finally:
|
||||
string_file.close()
|
||||
|
||||
|
@ -21,19 +21,17 @@ from __future__ import print_function
|
||||
import collections
|
||||
import itertools
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
|
||||
from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework.test_util import TensorFlowTestCase
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -44,7 +42,6 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
@ -339,7 +336,7 @@ def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs):
|
||||
return [dict(t) for t in {tuple(d.items()) for d in res}]
|
||||
|
||||
|
||||
class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def _test_training_helper(self,
|
||||
num_units,
|
||||
@ -382,12 +379,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_training(self, num_units, input_size, batch_size, time, num_layers,
|
||||
variable_seq_lengths, time_major, dynamic_shape_input):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
self._test_training_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
@ -406,13 +400,10 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_training_fp16(self, num_units, input_size, batch_size, time,
|
||||
num_layers, variable_seq_lengths, time_major,
|
||||
dynamic_shape_input):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
self._test_training_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
@ -433,12 +424,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_inference(self, num_units, input_size, batch_size, time, num_layers,
|
||||
variable_seq_lengths, time_major, dynamic_shape_input):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
with self.session(use_gpu=True) as sess:
|
||||
(outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
|
||||
sess,
|
||||
@ -465,13 +453,10 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_inference_fp16(self, num_units, input_size, batch_size, time,
|
||||
num_layers, variable_seq_lengths, time_major,
|
||||
dynamic_shape_input):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
with self.session(use_gpu=True) as sess:
|
||||
(outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
|
||||
sess,
|
||||
@ -502,14 +487,11 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
|
||||
num_layers, variable_seq_lengths, time_major,
|
||||
dynamic_shape_input):
|
||||
"""Validates that dropout does not affect Cudnn Rnn inference."""
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
# Hand-picked dropouts are used below (0. and 1.)
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.session(use_gpu=True, graph=g) as sess:
|
||||
@ -721,7 +703,7 @@ def RunGRU(sess,
|
||||
return outputs, cu_outputs, h, cu_h
|
||||
|
||||
|
||||
class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
class CudnnGRUTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def _test_training_helper(self,
|
||||
num_units,
|
||||
@ -764,12 +746,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_training(self, num_units, input_size, batch_size, time, num_layers,
|
||||
variable_seq_lengths, time_major, dynamic_shape_input):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
self._test_training_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
@ -788,13 +767,10 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_training_fp16(self, num_units, input_size, batch_size, time,
|
||||
num_layers, variable_seq_lengths, time_major,
|
||||
dynamic_shape_input):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
self._test_training_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
@ -815,12 +791,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_inference(self, num_units, input_size, batch_size, time, num_layers,
|
||||
variable_seq_lengths, time_major, dynamic_shape_input):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
with self.session(use_gpu=True) as sess:
|
||||
(outputs, cu_outputs, h, cu_h) = RunGRU(
|
||||
sess,
|
||||
@ -843,13 +816,10 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_inference_fp16(self, num_units, input_size, batch_size, time,
|
||||
num_layers, variable_seq_lengths, time_major,
|
||||
dynamic_shape_input):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
with self.session(use_gpu=True) as sess:
|
||||
(outputs, cu_outputs, h, cu_h) = RunGRU(
|
||||
sess,
|
||||
@ -875,15 +845,12 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
|
||||
num_layers, variable_seq_lengths, time_major,
|
||||
dynamic_shape_input):
|
||||
"""Validates that dropout does not affect Cudnn Rnn inference."""
|
||||
# Hand-picked dropouts are used below (0. and 1.)
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.session(use_gpu=True, graph=g) as sess:
|
||||
# 1st time w/o dropout.
|
||||
@ -919,7 +886,7 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertAllClose(cu_h[0], cu_h2[0])
|
||||
|
||||
|
||||
class CudnnParamsFormatConverterTest(TensorFlowTestCase,
|
||||
class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
"""Class for testing various format converters."""
|
||||
|
||||
@ -970,22 +937,16 @@ class CudnnParamsFormatConverterTest(TensorFlowTestCase,
|
||||
@parameterized.named_parameters((c["testcase_name"], c["num_units"],
|
||||
c["input_size"], c["num_layers"])
|
||||
for c in NAMED_RNN_TESTCASES)
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_lstm(self, num_units, input_size, num_layers):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
self._test_lstm_helper(num_units, input_size, num_layers,
|
||||
cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
|
||||
|
||||
@parameterized.named_parameters((c["testcase_name"], c["num_units"],
|
||||
c["input_size"], c["num_layers"])
|
||||
for c in NAMED_RNN_TESTCASES)
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_lstm_bidi(self, num_units, input_size, num_layers):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
self._test_lstm_helper(num_units, input_size, num_layers,
|
||||
cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)
|
||||
|
||||
@ -1044,27 +1005,22 @@ class CudnnParamsFormatConverterTest(TensorFlowTestCase,
|
||||
@parameterized.named_parameters((c["testcase_name"], c["num_units"],
|
||||
c["input_size"], c["num_layers"])
|
||||
for c in NAMED_RNN_TESTCASES)
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_gru(self, num_units, input_size, num_layers):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
self._test_gru_helper(num_units, input_size, num_layers,
|
||||
cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
|
||||
|
||||
@parameterized.named_parameters((c["testcase_name"], c["num_units"],
|
||||
c["input_size"], c["num_layers"])
|
||||
for c in NAMED_RNN_TESTCASES)
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_gru_bidi(self, num_units, input_size, num_layers):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
self._test_gru_helper(num_units, input_size, num_layers,
|
||||
cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)
|
||||
|
||||
|
||||
class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
class CudnnRnnSaveRestoreTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
"""Class for testing various Cudnn Rnn SaveableObjects."""
|
||||
|
||||
def _create_opaque_param(self,
|
||||
@ -1112,14 +1068,11 @@ class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
],
|
||||
"direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION]
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_save_restore_variable(self, rnn_mode, num_units, input_size,
|
||||
num_layers, direction):
|
||||
# Verify the restored opaque param, once converted to tf_canonical format,
|
||||
# is the same as the tf canonicals of the pre-restored param.
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
with self.session(use_gpu=True) as sess:
|
||||
opaque_param = self._create_opaque_param(rnn_mode, num_units, input_size,
|
||||
num_layers, direction)
|
||||
@ -1164,14 +1117,11 @@ class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase):
|
||||
],
|
||||
"direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION]
|
||||
}))
|
||||
@unittest.skipUnless(test.is_built_with_cuda(),
|
||||
"Test only applicable when running on GPUs")
|
||||
@test_util.run_gpu_only
|
||||
def test_save_restore_multi_variables(self, rnn_mode, num_units, input_size,
|
||||
num_layers, direction):
|
||||
# Verify the restored opaque param, once converted to tf_canonical format,
|
||||
# is the same as the tf canonicals of the pre-restored param.
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
with self.session(use_gpu=True) as sess:
|
||||
opaque_params = []
|
||||
saveables = []
|
||||